diff -r 4696f8d9bbea -r 3b5d0f72809e Lib/ssl.py --- a/Lib/ssl.py Fri Jul 11 00:37:31 2014 -0400 +++ b/Lib/ssl.py Wed Jul 09 11:22:08 2014 +0200 @@ -97,7 +97,7 @@ import _ssl # if we can't import it, let the error propagate from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION -from _ssl import _SSLContext +from _ssl import _SSLContext, _MemoryBIO from _ssl import ( SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError, SSLSyscallError, SSLEOFError, diff -r 4696f8d9bbea -r 3b5d0f72809e Lib/test/test_ssl.py --- a/Lib/test/test_ssl.py Fri Jul 11 00:37:31 2014 -0400 +++ b/Lib/test/test_ssl.py Wed Jul 09 11:22:08 2014 +0200 @@ -1220,6 +1220,69 @@ self.assertEqual(cm.exception.errno, ssl.SSL_ERROR_WANT_READ) +class MemoryBIOTests(unittest.TestCase): + + def test_read_write(self): + bio = ssl._MemoryBIO() + bio.write(b'foo') + self.assertEqual(bio.read(), b'foo') + self.assertEqual(bio.read(), b'') + bio.write(b'foo') + bio.write(b'bar') + self.assertEqual(bio.read(), b'foobar') + self.assertEqual(bio.read(), b'') + bio.write(b'baz') + self.assertEqual(bio.read(2), b'ba') + self.assertEqual(bio.read(1), b'z') + self.assertEqual(bio.read(1), b'') + + def test_eof(self): + bio = ssl._MemoryBIO() + self.assertFalse(bio.eof) + self.assertEqual(bio.read(), b'') + self.assertFalse(bio.eof) + bio.write(b'foo') + self.assertFalse(bio.eof) + bio.write_eof() + self.assertFalse(bio.eof) + self.assertEqual(bio.read(2), b'fo') + self.assertFalse(bio.eof) + self.assertEqual(bio.read(1), b'o') + self.assertTrue(bio.eof) + self.assertEqual(bio.read(), b'') + self.assertTrue(bio.eof) + + def test_pending(self): + bio = ssl._MemoryBIO() + self.assertEqual(bio.pending, 0) + bio.write(b'foo') + self.assertEqual(bio.pending, 3) + for i in range(3): + bio.read(1) + self.assertEqual(bio.pending, 3-i-1) + for i in range(3): + bio.write(b'x') + self.assertEqual(bio.pending, i+1) + bio.read() + self.assertEqual(bio.pending, 0) + + def test_buffer_types(self): + bio = ssl._MemoryBIO() + bio.write(b'foo') + self.assertEqual(bio.read(), b'foo') + bio.write(bytearray(b'bar')) + self.assertEqual(bio.read(), b'bar') + bio.write(memoryview(b'baz')) + self.assertEqual(bio.read(), b'baz') + + def test_error_types(self): + bio = ssl._MemoryBIO() + self.assertRaises(TypeError, bio.write, 'foo') + self.assertRaises(TypeError, bio.write, None) + self.assertRaises(TypeError, bio.write, True) + self.assertRaises(TypeError, bio.write, 1) + + class NetworkedTests(unittest.TestCase): def test_connect(self): @@ -1550,6 +1613,105 @@ self.assertIs(ss.context, ctx2) self.assertIs(ss._sslobj.context, ctx2) + +class NetworkedBIOTests(unittest.TestCase): + + def retry_eintr(self, func, *args): + # Retry func(*args) until it does not raise EINTR anymore. + while True: + try: + ret = func(*args) + except OSError as e: + if e.errno != errno.EINTR: + raise + continue + break + return ret + + def ssl_io_loop(self, sock, incoming, outgoing, func, *args, **kwargs): + # A simple IO loop. Call func(*args) depending on the error we get + # (WANT_READ or WANT_WRITE) move data between the socket and the BIOs. + timeout = kwargs.get('timeout', 10) + count = 0 + while True: + errno = None + count += 1 + try: + ret = func(*args) + except ssl.SSLError as e: + # Note that we get a spurious -1/SSL_ERROR_SYSCALL for + # non-blocking IO. The SSL_shutdown manpage hints at this. + # It *should* be safe to just ignore SYS_ERROR_SYSCALL because + # with a Memory BIO there's no syscalls (for IO at least). + if e.errno not in (ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE, + ssl.SSL_ERROR_SYSCALL): + raise + errno = e.errno + # Get any data from the outgoing BIO irrespective of any error, and + # send it to the socket. + buf = outgoing.read() + offset = 0 + while offset < len(buf): + self.retry_eintr(select.select, [], [sock], [], timeout) + offset += self.retry_eintr(sock.send, buf[offset:]) + # If there's no error, we're done. For WANT_READ, we need to get + # data from the socket and ptu it in the incoming BIO. + if errno is None: + break + elif errno == ssl.SSL_ERROR_WANT_READ: + self.retry_eintr(select.select, [sock], [], [], timeout) + buf = self.retry_eintr(sock.recv, 32768) + if buf: + incoming.write(buf) + else: + incoming.write_eof() + if support.verbose: + sys.stdout.write("Needed %d calls to complete %s().\n" + % (count, func.__name__)) + return ret + + def test_handshake(self): + with support.transient_internet("svn.python.org"): + sock = socket.socket(socket.AF_INET) + sock.connect(("svn.python.org", 443)) + incoming = ssl._MemoryBIO() + outgoing = ssl._MemoryBIO() + ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx.verify_mode = ssl.CERT_REQUIRED + ctx.load_verify_locations(SVN_PYTHON_ORG_ROOT_CERT) + sslobj = ctx._wrap_bio(incoming, outgoing, 0) + self.assertIsNone(sslobj.cipher()) + self.assertRaises(ValueError, sslobj.peer_certificate) + if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES: + self.assertIsNone(sslobj.tls_unique_cb()) + self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake) + self.assertTrue(sslobj.cipher()) + self.assertTrue(sslobj.peer_certificate()) + if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES: + self.assertTrue(sslobj.tls_unique_cb()) + self.ssl_io_loop(sock, incoming, outgoing, sslobj.shutdown) + self.assertRaises(ssl.SSLError, sslobj.write, b'foo') + sock.close() + + def test_read_write_data(self): + with support.transient_internet("svn.python.org"): + sock = socket.socket(socket.AF_INET) + sock.connect(("svn.python.org", 443)) + incoming = ssl._MemoryBIO() + outgoing = ssl._MemoryBIO() + ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx.verify_mode = ssl.CERT_NONE + sslobj = ctx._wrap_bio(incoming, outgoing, 0) + self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake) + req = b'GET / HTTP/1.0\r\n\r\n' + self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req) + buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024) + self.assertEqual(buf[:5], b'HTTP/') + self.ssl_io_loop(sock, incoming, outgoing, sslobj.shutdown) + sock.close() + + try: import threading except ImportError: @@ -3007,10 +3169,11 @@ if not os.path.exists(filename): raise support.TestFailed("Can't read certificate file %r" % filename) - tests = [ContextTests, BasicSocketTests, SSLErrorTests] + tests = [ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests] if support.is_resource_enabled('network'): tests.append(NetworkedTests) + tests.append(NetworkedBIOTests) if _have_threads: thread_info = support.threading_setup() diff -r 4696f8d9bbea -r 3b5d0f72809e Modules/_ssl.c --- a/Modules/_ssl.c Fri Jul 11 00:37:31 2014 -0400 +++ b/Modules/_ssl.c Wed Jul 09 11:22:08 2014 +0200 @@ -64,6 +64,7 @@ #include "openssl/ssl.h" #include "openssl/err.h" #include "openssl/rand.h" +#include "openssl/bio.h" /* SSL error object */ static PyObject *PySSLErrorObject; @@ -228,8 +229,15 @@ enum py_ssl_server_or_client socket_type; } PySSLSocket; +typedef struct { + PyObject_HEAD + BIO *bio; + int eof_written; +} PySSLMemoryBIO; + static PyTypeObject PySSLContext_Type; static PyTypeObject PySSLSocket_Type; +static PyTypeObject PySSLMemoryBIO_Type; static PyObject *PySSL_SSLwrite(PySSLSocket *self, PyObject *args); static PyObject *PySSL_SSLread(PySSLSocket *self, PyObject *args); @@ -240,6 +248,7 @@ #define PySSLContext_Check(v) (Py_TYPE(v) == &PySSLContext_Type) #define PySSLSocket_Check(v) (Py_TYPE(v) == &PySSLSocket_Type) +#define PySSLMemoryBIO_Check(v) (Py_TYPE(v) == &PySSLMemoryBIO_Type) typedef enum { SOCKET_IS_NONBLOCKING, @@ -254,6 +263,9 @@ #define ERRSTR1(x,y,z) (x ":" y ": " z) #define ERRSTR(x) ERRSTR1("_ssl.c", Py_STRINGIFY(__LINE__), x) +/* Get the socket from a PySSLSocket, if it has one */ +#define GET_SOCKET(obj) ((obj)->Socket ? \ + (PySocketSockObject *) PyWeakref_GetObject((obj)->Socket) : NULL) /* * SSL errors. @@ -417,13 +429,12 @@ case SSL_ERROR_SYSCALL: { if (e == 0) { - PySocketSockObject *s - = (PySocketSockObject *) PyWeakref_GetObject(obj->Socket); + PySocketSockObject *s = GET_SOCKET(obj); if (ret == 0 || (((PyObject *)s) == Py_None)) { p = PY_SSL_ERROR_EOF; type = PySSLEOFErrorObject; errstr = "EOF occurred in violation of protocol"; - } else if (ret == -1) { + } else if (s && ret == -1) { /* underlying BIO reported an I/O error */ Py_INCREF(s); ERR_clear_error(); @@ -477,7 +488,8 @@ static PySSLSocket * newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock, enum py_ssl_server_or_client socket_type, - char *server_hostname) + char *server_hostname, + PySSLMemoryBIO *inbio, PySSLMemoryBIO *outbio) { PySSLSocket *self; SSL_CTX *ctx = sslctx->ctx; @@ -502,8 +514,17 @@ PySSL_BEGIN_ALLOW_THREADS self->ssl = SSL_new(ctx); PySSL_END_ALLOW_THREADS - SSL_set_app_data(self->ssl,self); - SSL_set_fd(self->ssl, Py_SAFE_DOWNCAST(sock->sock_fd, SOCKET_T, int)); + SSL_set_app_data(self->ssl, self); + if (sock) { + SSL_set_fd(self->ssl, Py_SAFE_DOWNCAST(sock->sock_fd, SOCKET_T, int)); + } else { + /* BIOs are reference counted and SSL_set_bio borrows our reference. + * To prevent a double free in memory_bio_dealloc() we need to take an + * extra reference here. */ + CRYPTO_add(&inbio->bio->references, 1, CRYPTO_LOCK_BIO); + CRYPTO_add(&outbio->bio->references, 1, CRYPTO_LOCK_BIO); + SSL_set_bio(self->ssl, inbio->bio, outbio->bio); + } mode = SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER; #ifdef SSL_MODE_AUTO_RETRY mode |= SSL_MODE_AUTO_RETRY; @@ -518,7 +539,7 @@ /* If the socket is in non-blocking mode or timeout mode, set the BIO * to non-blocking mode (blocking is the default) */ - if (sock->sock_timeout >= 0.0) { + if (sock && sock->sock_timeout >= 0.0) { BIO_set_nbio(SSL_get_rbio(self->ssl), 1); BIO_set_nbio(SSL_get_wbio(self->ssl), 1); } @@ -531,10 +552,12 @@ PySSL_END_ALLOW_THREADS self->socket_type = socket_type; - self->Socket = PyWeakref_NewRef((PyObject *) sock, NULL); - if (self->Socket == NULL) { - Py_DECREF(self); - return NULL; + if (sock != NULL) { + self->Socket = PyWeakref_NewRef((PyObject *) sock, NULL); + if (self->Socket == NULL) { + Py_DECREF(self); + return NULL; + } } return self; } @@ -546,20 +569,21 @@ int ret; int err; int sockstate, nonblocking; - PySocketSockObject *sock - = (PySocketSockObject *) PyWeakref_GetObject(self->Socket); - - if (((PyObject*)sock) == Py_None) { - _setSSLError("Underlying socket connection gone", - PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__); - return NULL; + PySocketSockObject *sock = GET_SOCKET(self); + + if (sock) { + if (((PyObject*)sock) == Py_None) { + _setSSLError("Underlying socket connection gone", + PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__); + return NULL; + } + Py_INCREF(sock); + + /* just in case the blocking state of the socket has been changed */ + nonblocking = (sock->sock_timeout >= 0.0); + BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); + BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); } - Py_INCREF(sock); - - /* just in case the blocking state of the socket has been changed */ - nonblocking = (sock->sock_timeout >= 0.0); - BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); - BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); /* Actually negotiate SSL connection */ /* XXX If SSL_do_handshake() returns 0, it's also a failure. */ @@ -593,7 +617,7 @@ break; } } while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE); - Py_DECREF(sock); + Py_XDECREF(sock); if (ret < 1) return PySSL_SetError(self, ret, __FILE__, __LINE__); @@ -608,7 +632,7 @@ return Py_None; error: - Py_DECREF(sock); + Py_XDECREF(sock); return NULL; } @@ -1496,10 +1520,10 @@ int rc; /* Nothing to do unless we're in timeout mode (not non-blocking) */ - if (s->sock_timeout < 0.0) + if ((s == NULL) || (s->sock_timeout == 0.0)) + return SOCKET_IS_NONBLOCKING; + else if (s->sock_timeout < 0.0) return SOCKET_IS_BLOCKING; - else if (s->sock_timeout == 0.0) - return SOCKET_IS_NONBLOCKING; /* Guard against closed socket */ if (s->sock_fd < 0) @@ -1560,18 +1584,19 @@ int sockstate; int err; int nonblocking; - PySocketSockObject *sock - = (PySocketSockObject *) PyWeakref_GetObject(self->Socket); - - if (((PyObject*)sock) == Py_None) { - _setSSLError("Underlying socket connection gone", - PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__); - return NULL; + PySocketSockObject *sock = GET_SOCKET(self); + + if (sock != NULL) { + if (((PyObject*)sock) == Py_None) { + _setSSLError("Underlying socket connection gone", + PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__); + return NULL; + } + Py_INCREF(sock); } - Py_INCREF(sock); if (!PyArg_ParseTuple(args, "y*:write", &buf)) { - Py_DECREF(sock); + Py_XDECREF(sock); return NULL; } @@ -1581,10 +1606,12 @@ goto error; } - /* just in case the blocking state of the socket has been changed */ - nonblocking = (sock->sock_timeout >= 0.0); - BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); - BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); + if (sock != NULL) { + /* just in case the blocking state of the socket has been changed */ + nonblocking = (sock->sock_timeout >= 0.0); + BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); + BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); + } sockstate = check_socket_and_wait_for_timeout(sock, 1); if (sockstate == SOCKET_HAS_TIMED_OUT) { @@ -1628,7 +1655,7 @@ } } while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE); - Py_DECREF(sock); + Py_XDECREF(sock); PyBuffer_Release(&buf); if (len > 0) return PyLong_FromLong(len); @@ -1636,7 +1663,7 @@ return PySSL_SetError(self, len, __FILE__, __LINE__); error: - Py_DECREF(sock); + Py_XDECREF(sock); PyBuffer_Release(&buf); return NULL; } @@ -1676,15 +1703,16 @@ int sockstate; int err; int nonblocking; - PySocketSockObject *sock - = (PySocketSockObject *) PyWeakref_GetObject(self->Socket); - - if (((PyObject*)sock) == Py_None) { - _setSSLError("Underlying socket connection gone", - PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__); - return NULL; + PySocketSockObject *sock = GET_SOCKET(self); + + if (sock != NULL) { + if (((PyObject*)sock) == Py_None) { + _setSSLError("Underlying socket connection gone", + PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__); + return NULL; + } + Py_INCREF(sock); } - Py_INCREF(sock); buf.obj = NULL; buf.buf = NULL; @@ -1710,10 +1738,12 @@ } } - /* just in case the blocking state of the socket has been changed */ - nonblocking = (sock->sock_timeout >= 0.0); - BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); - BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); + if (sock != NULL) { + /* just in case the blocking state of the socket has been changed */ + nonblocking = (sock->sock_timeout >= 0.0); + BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); + BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); + } /* first check if there are bytes ready to be read */ PySSL_BEGIN_ALLOW_THREADS @@ -1769,7 +1799,7 @@ } done: - Py_DECREF(sock); + Py_XDECREF(sock); if (!buf_passed) { _PyBytes_Resize(&dest, count); return dest; @@ -1780,7 +1810,7 @@ } error: - Py_DECREF(sock); + Py_XDECREF(sock); if (!buf_passed) Py_XDECREF(dest); else @@ -1797,21 +1827,22 @@ { int err, ssl_err, sockstate, nonblocking; int zeros = 0; - PySocketSockObject *sock - = (PySocketSockObject *) PyWeakref_GetObject(self->Socket); - - /* Guard against closed socket */ - if ((((PyObject*)sock) == Py_None) || (sock->sock_fd < 0)) { - _setSSLError("Underlying socket connection gone", - PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__); - return NULL; + PySocketSockObject *sock = GET_SOCKET(self); + + if (sock != NULL) { + /* Guard against closed socket */ + if ((((PyObject*)sock) == Py_None) || (sock->sock_fd < 0)) { + _setSSLError("Underlying socket connection gone", + PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__); + return NULL; + } + Py_INCREF(sock); + + /* Just in case the blocking state of the socket has been changed */ + nonblocking = (sock->sock_timeout >= 0.0); + BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); + BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); } - Py_INCREF(sock); - - /* Just in case the blocking state of the socket has been changed */ - nonblocking = (sock->sock_timeout >= 0.0); - BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); - BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); while (1) { PySSL_BEGIN_ALLOW_THREADS @@ -1869,15 +1900,19 @@ } if (err < 0) { - Py_DECREF(sock); + Py_XDECREF(sock); return PySSL_SetError(self, err, __FILE__, __LINE__); } - else + if (sock) { /* It's already INCREF'ed */ return (PyObject *) sock; + } else { + Py_INCREF(Py_None); + return Py_None; + } error: - Py_DECREF(sock); + Py_XDECREF(sock); return NULL; } @@ -2812,14 +2847,49 @@ #endif } - res = (PyObject *) newPySSLSocket(self, sock, server_side, - hostname); + res = (PyObject *) newPySSLSocket(self, sock, server_side, hostname, + NULL, NULL); if (hostname != NULL) PyMem_Free(hostname); return res; } static PyObject * +context_wrap_bio(PySSLContext *self, PyObject *args, PyObject *kwds) +{ + char *kwlist[] = {"incoming", "outgoing", "server_side", + "server_hostname", NULL}; + int server_side; + char *hostname = NULL; + PyObject *hostname_obj = Py_None, *res; + PySSLMemoryBIO *incoming, *outgoing; + + /* server_hostname is either None (or absent), or to be encoded + using the idna encoding. */ + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!O!i|O:_wrap_bio", kwlist, + &PySSLMemoryBIO_Type, &incoming, + &PySSLMemoryBIO_Type, &outgoing, + &server_side, &hostname_obj)) + return NULL; + if (hostname_obj != Py_None) { +#if HAVE_SNI + if (!PyArg_Parse(hostname_obj, "et", "idna", &hostname)) + return NULL; +#else + PyErr_SetString(PyExc_ValueError, "server_hostname is not supported " + "by your OpenSSL library"); + return NULL; +#endif + } + + res = (PyObject *) newPySSLSocket(self, NULL, server_side, hostname, + incoming, outgoing); + + PyMem_Free(hostname); + return res; +} + +static PyObject * session_stats(PySSLContext *self, PyObject *unused) { int r; @@ -2925,10 +2995,20 @@ ssl = SSL_get_app_data(s); assert(PySSLSocket_Check(ssl)); - ssl_socket = PyWeakref_GetObject(ssl->Socket); - Py_INCREF(ssl_socket); - if (ssl_socket == Py_None) { - goto error; + + /* Pass a PySSLSocket instance when using memory BIOs, but an ssl.SSLSocket + * when using sockets. Note that the latter is not a subclass of the + * former, but both do have a "context" property. THis supports the common + * use case of setting this property in the servername callback. */ + ssl_socket = (PyObject *) GET_SOCKET(ssl); + if (ssl_socket == NULL) { + Py_INCREF(ssl); + ssl_socket = (PyObject *) ssl; + } else { + Py_INCREF(ssl_socket); + if (ssl_socket == Py_None) { + goto error; + } } if (servername == NULL) { @@ -3158,6 +3238,8 @@ static struct PyMethodDef context_methods[] = { {"_wrap_socket", (PyCFunction) context_wrap_socket, METH_VARARGS | METH_KEYWORDS, NULL}, + {"_wrap_bio", (PyCFunction) context_wrap_bio, + METH_VARARGS | METH_KEYWORDS, NULL}, {"set_ciphers", (PyCFunction) set_ciphers, METH_VARARGS, NULL}, {"_set_npn_protocols", (PyCFunction) _set_npn_protocols, @@ -3227,6 +3309,216 @@ }; +/* + * _MemoryBIO objects + */ + +static PyObject * +memory_bio_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + char *kwlist[] = {NULL}; + BIO *bio; + PySSLMemoryBIO *self; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, ":_MemoryBIO", kwlist)) + return NULL; + + bio = BIO_new(BIO_s_mem()); + if (bio == NULL) { + PyErr_SetString(PySSLErrorObject, + "failed to allocate BIO"); + return NULL; + } + BIO_set_retry_read(bio); + BIO_set_mem_eof_return(bio, -1); + + assert(type != NULL && type->tp_alloc != NULL); + self = (PySSLMemoryBIO *) type->tp_alloc(type, 0); + if (self == NULL) { + BIO_free(bio); + return NULL; + } + self->bio = bio; + self->eof_written = 0; + + return (PyObject *) self; +} + +static void +memory_bio_dealloc(PySSLMemoryBIO *self) +{ + BIO_free(self->bio); + Py_TYPE(self)->tp_free(self); +} + +static PyObject * +memory_bio_get_pending(PySSLMemoryBIO *self, void *c) +{ + return PyLong_FromLong(BIO_ctrl_pending(self->bio)); +} + +PyDoc_STRVAR(PySSL_memory_bio_pending_doc, +"The number of bytes pending in the memory BIO."); + +static PyObject * +memory_bio_get_eof(PySSLMemoryBIO *self, void *c) +{ + return PyBool_FromLong((BIO_ctrl_pending(self->bio) == 0) + && self->eof_written); +} + +PyDoc_STRVAR(PySSL_memory_bio_eof_doc, +"Whether the memory BIO is at EOF."); + +static PyObject * +memory_bio_read(PySSLMemoryBIO *self, PyObject *args) +{ + int len = -1, avail, nbytes; + PyObject *result; + + if (!PyArg_ParseTuple(args, "|i:read", &len)) + return NULL; + + avail = BIO_ctrl_pending(self->bio); + if ((len < 0) || (len > avail)) + len = avail; + + result = PyBytes_FromStringAndSize(NULL, len); + if ((result == NULL) || (len == 0)) + return result; + + nbytes = BIO_read(self->bio, PyBytes_AS_STRING(result), len); + /* There should never be any short reads but check anyway. */ + if ((nbytes < len) && (_PyBytes_Resize(&result, len) < 0)) { + Py_DECREF(result); + return NULL; + } + + return result; +} + +PyDoc_STRVAR(PySSL_memory_bio_read_doc, +"read([len]) -> bytes\n\ +\n\ +Read up to len bytes from the memory BIO.\n\ +\n\ +If len is not specified, read the entire buffer.\n\ +If the return value is an empty bytes instance, this means either\n\ +EOF or that no data is available. Use the \"eof\" property to\n\ +distinguish between the two."); + +static PyObject * +memory_bio_write(PySSLMemoryBIO *self, PyObject *args) +{ + Py_buffer buf; + int nbytes; + + if (!PyArg_ParseTuple(args, "y*:write", &buf)) + return NULL; + + if (buf.len > INT_MAX) { + PyErr_Format(PyExc_OverflowError, + "string longer than %d bytes", INT_MAX); + return NULL; + } + + if (self->eof_written) { + PyErr_SetString(PySSLErrorObject, + "cannot write() after write_eof()"); + return NULL; + } + + nbytes = BIO_write(self->bio, buf.buf, buf.len); + if (nbytes < 0) { + _setSSLError(NULL, 0, __FILE__, __LINE__); + return NULL; + } + + return PyLong_FromLong(nbytes); +} + +PyDoc_STRVAR(PySSL_memory_bio_write_doc, +"write(b) -> len\n\ +\n\ +Writes the bytes b into the memory BIO. Returns the number\n\ +of bytes written."); + +static PyObject * +memory_bio_write_eof(PySSLMemoryBIO *self, PyObject *args) +{ + self->eof_written = 1; + BIO_clear_retry_flags(self->bio); + BIO_set_mem_eof_return(self->bio, 0); + + Py_INCREF(Py_None); + return Py_None; +} + +PyDoc_STRVAR(PySSL_memory_bio_write_eof_doc, +"write_eof()\n\ +\n\ +Write an EOF marker to the memory BIO.\n\ +When all data has been read, the \"eof\" property will be True."); + +static PyGetSetDef memory_bio_getsetlist[] = { + {"pending", (getter) memory_bio_get_pending, NULL, + PySSL_memory_bio_pending_doc}, + {"eof", (getter) memory_bio_get_eof, NULL, + PySSL_memory_bio_eof_doc}, + {NULL}, /* sentinel */ +}; + +static struct PyMethodDef memory_bio_methods[] = { + {"read", (PyCFunction) memory_bio_read, + METH_VARARGS, PySSL_memory_bio_read_doc}, + {"write", (PyCFunction) memory_bio_write, + METH_VARARGS, PySSL_memory_bio_write_doc}, + {"write_eof", (PyCFunction) memory_bio_write_eof, + METH_NOARGS, PySSL_memory_bio_write_eof_doc}, + {NULL, NULL} /* sentinel */ +}; + +static PyTypeObject PySSLMemoryBIO_Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "_ssl._MemoryBIO", /*tp_name*/ + sizeof(PySSLMemoryBIO), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + (destructor)memory_bio_dealloc, /*tp_dealloc*/ + 0, /*tp_print*/ + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + 0, /*tp_reserved*/ + 0, /*tp_repr*/ + 0, /*tp_as_number*/ + 0, /*tp_as_sequence*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash*/ + 0, /*tp_call*/ + 0, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ + 0, /*tp_doc*/ + 0, /*tp_traverse*/ + 0, /*tp_clear*/ + 0, /*tp_richcompare*/ + 0, /*tp_weaklistoffset*/ + 0, /*tp_iter*/ + 0, /*tp_iternext*/ + memory_bio_methods, /*tp_methods*/ + 0, /*tp_members*/ + memory_bio_getsetlist, /*tp_getset*/ + 0, /*tp_base*/ + 0, /*tp_dict*/ + 0, /*tp_descr_get*/ + 0, /*tp_descr_set*/ + 0, /*tp_dictoffset*/ + 0, /*tp_init*/ + 0, /*tp_alloc*/ + memory_bio_new, /*tp_new*/ +}; + #ifdef HAVE_OPENSSL_RAND @@ -3914,6 +4206,8 @@ return NULL; if (PyType_Ready(&PySSLSocket_Type) < 0) return NULL; + if (PyType_Ready(&PySSLMemoryBIO_Type) < 0) + return NULL; m = PyModule_Create(&_sslmodule); if (m == NULL) @@ -3977,6 +4271,9 @@ if (PyDict_SetItemString(d, "_SSLSocket", (PyObject *)&PySSLSocket_Type) != 0) return NULL; + if (PyDict_SetItemString(d, "_MemoryBIO", + (PyObject *)&PySSLMemoryBIO_Type) != 0) + return NULL; PyModule_AddIntConstant(m, "SSL_ERROR_ZERO_RETURN", PY_SSL_ERROR_ZERO_RETURN); PyModule_AddIntConstant(m, "SSL_ERROR_WANT_READ",