diff -r a9a38b274b8a Lib/ssl.py --- a/Lib/ssl.py Sat Jun 19 21:58:37 2010 +0200 +++ b/Lib/ssl.py Sun Jun 20 22:50:47 2010 +0200 @@ -59,7 +59,7 @@ import textwrap import _ssl # if we can't import it, let the error propagate from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION -from _ssl import _SSLContext, SSLError +from _ssl import _SSLContext, _SSLSocket, SSLError from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED from _ssl import (PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1) @@ -107,49 +107,35 @@ class SSLContext(_SSLContext): _context=self) -class SSLSocket(socket): +class SSLSocket(_SSLSocket): """This class implements a subtype of socket.socket that wraps the underlying OS socket in an SSL context when necessary, and provides read and write methods over that channel.""" - def __init__(self, sock=None, keyfile=None, certfile=None, - server_side=False, cert_reqs=CERT_NONE, - ssl_version=PROTOCOL_SSLv23, ca_certs=None, - do_handshake_on_connect=True, - family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, - suppress_ragged_eofs=True, ciphers=None, - _context=None): + def __new__(cls, sock=None, keyfile=None, certfile=None, + server_side=False, cert_reqs=CERT_NONE, + ssl_version=PROTOCOL_SSLv23, ca_certs=None, + do_handshake_on_connect=True, + family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, + suppress_ragged_eofs=True, ciphers=None, + _context=None): if _context: - self.context = _context + context = _context else: if certfile and not keyfile: keyfile = certfile - self.context = SSLContext(ssl_version) - self.context.verify_mode = cert_reqs + context = SSLContext(ssl_version) + context.verify_mode = cert_reqs if ca_certs: - self.context.load_verify_locations(ca_certs) + context.load_verify_locations(ca_certs) if certfile: - self.context.load_cert_chain(certfile, keyfile) + context.load_cert_chain(certfile, keyfile) if ciphers: - self.context.set_ciphers(ciphers) - self.keyfile = keyfile - self.certfile = certfile - self.cert_reqs = cert_reqs - self.ssl_version = ssl_version - self.ca_certs = ca_certs - self.ciphers = ciphers + context.set_ciphers(ciphers) - self.do_handshake_on_connect = do_handshake_on_connect - self.suppress_ragged_eofs = suppress_ragged_eofs connected = False if sock is not None: - socket.__init__(self, - family=sock.family, - type=sock.type, - proto=sock.proto, - fileno=_dup(sock.fileno())) - self.settimeout(sock.gettimeout()) # see if it's connected try: sock.getpeername() @@ -158,20 +144,38 @@ class SSLSocket(socket): raise else: connected = True - sock.close() elif fileno is not None: - socket.__init__(self, fileno=fileno) + sock = socket(fileno=fileno) else: - socket.__init__(self, family=family, type=type, proto=proto) + sock = socket(family=family, type=type, proto=proto) + self = _SSLSocket.__new__(cls, context, sock, server_side) + self.context = context + self._sock = sock + + self.keyfile = keyfile + self.certfile = certfile + self.cert_reqs = cert_reqs + self.ssl_version = ssl_version + self.ca_certs = ca_certs + self.ciphers = ciphers + + self.do_handshake_on_connect = do_handshake_on_connect + self.suppress_ragged_eofs = suppress_ragged_eofs + + # Delegation to socket methods self._closed = False - self._sslobj = None + self._io_refs = 0 + for attr in ( + "settimeout", "gettimeout", "setblocking", + "fileno", "getpeername", "getsockname"): + setattr(self, attr, getattr(sock, attr)) + if connected: # create the SSL object try: - self._sslobj = self.context._wrap_socket(self, server_side) if do_handshake_on_connect: - timeout = self.gettimeout() + timeout = self._sock.gettimeout() if timeout == 0.0: # non-blocking raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets") @@ -181,6 +185,20 @@ class SSLSocket(socket): self.close() raise x + return self + + # Re-use the socket implementation for these methods. It would probably + # be better to factor out this code in a common mixin. + makefile = socket.makefile + close = socket.close + _decref_socketios = socket._decref_socketios + + # For backwards compatibility + @property + def _sslobj(self): + # XXX raise DeprecationWarning + return self + def dup(self): raise NotImplemented("Can't dup() %s instances" % self.__class__.__name__) @@ -192,13 +210,12 @@ class SSLSocket(socket): def read(self, len=0, buffer=None): """Read up to LEN bytes and return them. Return zero-length string on EOF.""" - self._checkClosed() try: if buffer: - v = self._sslobj.read(buffer, len) + v = _SSLSocket.read(self, buffer, len) else: - v = self._sslobj.read(len or 1024) + v = _SSLSocket.read(self, len or 1024) return v except SSLError as x: if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs: @@ -212,9 +229,8 @@ class SSLSocket(socket): def write(self, data): """Write DATA to the underlying SSL channel. Returns number of bytes of DATA actually transmitted.""" - self._checkClosed() - return self._sslobj.write(data) + return _SSLSocket.write(self, data) def getpeercert(self, binary_form=False): """Returns a formatted version of the data in the @@ -223,67 +239,52 @@ class SSLSocket(socket): certificate was provided, but not validated.""" self._checkClosed() - return self._sslobj.peer_certificate(binary_form) + return _SSLSocket.peer_certificate(self, binary_form) def cipher(self): self._checkClosed() - if not self._sslobj: - return None - else: - return self._sslobj.cipher() + return _SSLSocket.cipher(self) def send(self, data, flags=0): self._checkClosed() - if self._sslobj: - if flags != 0: - raise ValueError( - "non-zero flags not allowed in calls to send() on %s" % - self.__class__) - while True: - try: - v = self._sslobj.write(data) - except SSLError as x: - if x.args[0] == SSL_ERROR_WANT_READ: - return 0 - elif x.args[0] == SSL_ERROR_WANT_WRITE: - return 0 - else: - raise + if flags != 0: + raise ValueError( + "non-zero flags not allowed in calls to send() on %s" % + self.__class__) + while True: + try: + v = _SSLSocket.write(self, data) + except SSLError as x: + if x.args[0] == SSL_ERROR_WANT_READ: + return 0 + elif x.args[0] == SSL_ERROR_WANT_WRITE: + return 0 else: - return v - else: - return socket.send(self, data, flags) + raise + else: + return v def sendto(self, data, addr, flags=0): self._checkClosed() - if self._sslobj: - raise ValueError("sendto not allowed on instances of %s" % - self.__class__) - else: - return socket.sendto(self, data, addr, flags) + raise ValueError("sendto not allowed on instances of %s" % + self.__class__) def sendall(self, data, flags=0): self._checkClosed() - if self._sslobj: - amount = len(data) - count = 0 - while (count < amount): - v = self.send(data[count:]) - count += v - return amount - else: - return socket.sendall(self, data, flags) + amount = len(data) + count = 0 + while (count < amount): + v = self.send(data[count:]) + count += v + return amount def recv(self, buflen=1024, flags=0): self._checkClosed() - if self._sslobj: - if flags != 0: - raise ValueError( - "non-zero flags not allowed in calls to recv() on %s" % - self.__class__) - return self.read(buflen) - else: - return socket.recv(self, buflen, flags) + if flags != 0: + raise ValueError( + "non-zero flags not allowed in calls to recv() on %s" % + self.__class__) + return self.read(buflen) def recv_into(self, buffer, nbytes=None, flags=0): self._checkClosed() @@ -291,83 +292,53 @@ class SSLSocket(socket): nbytes = len(buffer) elif nbytes is None: nbytes = 1024 - if self._sslobj: - if flags != 0: - raise ValueError( - "non-zero flags not allowed in calls to recv_into() on %s" % - self.__class__) - return self.read(nbytes, buffer) - else: - return socket.recv_into(self, buffer, nbytes, flags) + if flags != 0: + raise ValueError( + "non-zero flags not allowed in calls to recv_into() on %s" % + self.__class__) + return self.read(nbytes, buffer) def recvfrom(self, addr, buflen=1024, flags=0): self._checkClosed() - if self._sslobj: - raise ValueError("recvfrom not allowed on instances of %s" % - self.__class__) - else: - return socket.recvfrom(self, addr, buflen, flags) + raise ValueError("recvfrom not allowed on instances of %s" % + self.__class__) def recvfrom_into(self, buffer, nbytes=None, flags=0): self._checkClosed() - if self._sslobj: - raise ValueError("recvfrom_into not allowed on instances of %s" % - self.__class__) - else: - return socket.recvfrom_into(self, buffer, nbytes, flags) + raise ValueError("recvfrom_into not allowed on instances of %s" % + self.__class__) def pending(self): self._checkClosed() - if self._sslobj: - return self._sslobj.pending() - else: - return 0 + return _SSLSocket.pending(self) def shutdown(self, how): self._checkClosed() - self._sslobj = None - socket.shutdown(self, how) + self._sock.shutdown(how) def unwrap(self): - if self._sslobj: - s = self._sslobj.shutdown() - self._sslobj = None - return s - else: - raise ValueError("No SSL wrapper around " + str(self)) + _SSLSocket.shutdown(self) + return self._sock def _real_close(self): - self._sslobj = None - # self._closed = True - socket._real_close(self) + self._sock._real_close() def do_handshake(self, block=False): """Perform a TLS/SSL handshake.""" - - timeout = self.gettimeout() + timeout = self._sock.gettimeout() try: if timeout == 0.0 and block: - self.settimeout(None) - self._sslobj.do_handshake() + self._sock.settimeout(None) + _SSLSocket.do_handshake(self) finally: - self.settimeout(timeout) + self._sock.settimeout(timeout) def connect(self, addr): """Connects to remote ADDR, and then wraps the connection in an SSL channel.""" - - # Here we assume that the socket is client-side, and not - # connected at the time of the call. We connect it, then wrap it. - if self._sslobj: - raise ValueError("attempt to connect already-connected SSLSocket!") - socket.connect(self, addr) - self._sslobj = self.context._wrap_socket(self, False) - try: - if self.do_handshake_on_connect: - self.do_handshake() - except: - self._sslobj = None - raise + self._sock.connect(addr) + if self.do_handshake_on_connect: + self.do_handshake() def accept(self): """Accepts a new connection from a remote client, and returns @@ -386,10 +357,6 @@ class SSLSocket(socket): self.do_handshake_on_connect), addr) - def __del__(self): - # sys.stderr.write("__del__ on %s\n" % repr(self)) - self._real_close() - def wrap_socket(sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, diff -r a9a38b274b8a Lib/test/test_ssl.py --- a/Lib/test/test_ssl.py Sat Jun 19 21:58:37 2010 +0200 +++ b/Lib/test/test_ssl.py Sun Jun 20 22:50:47 2010 +0200 @@ -810,7 +810,7 @@ else: try: asyncore.loop(1) except: - pass + traceback.print_exc() def stop(self): self.active = False diff -r a9a38b274b8a Modules/_ssl.c --- a/Modules/_ssl.c Sat Jun 19 21:58:37 2010 +0200 +++ b/Modules/_ssl.c Sun Jun 20 22:50:47 2010 +0200 @@ -274,13 +274,27 @@ _setSSLError (char *errstr, int errcode, return NULL; } -static PySSLSocket * -newPySSLSocket(SSL_CTX *ctx, PySocketSockObject *sock, - enum py_ssl_server_or_client socket_type) +/* + * SSL sockets + */ + +static PyObject * +PySSL_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { PySSLSocket *self; + char *kwlist[] = {"context", "sock", "server_side", NULL}; + PySSLContext *context; + PySocketSockObject *sock; + int server_side = 0; - self = PyObject_New(PySSLSocket, &PySSLSocket_Type); + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!O!i:_SSLSocket", kwlist, + &PySSLContext_Type, &context, + PySocketModule.Sock_Type, &sock, + &server_side)) + return NULL; + + assert(type != NULL && type->tp_alloc != NULL); + self = (PySSLSocket *) type->tp_alloc(type, 0); if (self == NULL) return NULL; @@ -293,7 +307,7 @@ newPySSLSocket(SSL_CTX *ctx, PySocketSoc ERR_clear_error(); PySSL_BEGIN_ALLOW_THREADS - self->ssl = SSL_new(ctx); + self->ssl = SSL_new(context->ctx); PySSL_END_ALLOW_THREADS SSL_set_fd(self->ssl, sock->sock_fd); #ifdef SSL_MODE_AUTO_RETRY @@ -309,18 +323,17 @@ newPySSLSocket(SSL_CTX *ctx, PySocketSoc } PySSL_BEGIN_ALLOW_THREADS - if (socket_type == PY_SSL_CLIENT) + if (server_side) + SSL_set_accept_state(self->ssl); + else SSL_set_connect_state(self->ssl); - else - SSL_set_accept_state(self->ssl); PySSL_END_ALLOW_THREADS self->Socket = PyWeakref_NewRef((PyObject *) sock, NULL); - return self; + + return (PyObject *)self; } -/* SSL object methods */ - static PyObject *PySSL_SSLdo_handshake(PySSLSocket *self) { int ret; @@ -963,7 +976,7 @@ static void PySSL_dealloc(PySSLSocket *s if (self->ssl) SSL_free(self->ssl); Py_XDECREF(self->Socket); - PyObject_Del(self); + Py_TYPE(self)->tp_free(self); } /* If the socket has a timeout, do a select()/poll() on the socket. @@ -1388,7 +1401,7 @@ static PyTypeObject PySSLSocket_Type = { 0, /*tp_getattro*/ 0, /*tp_setattro*/ 0, /*tp_as_buffer*/ - Py_TPFLAGS_DEFAULT, /*tp_flags*/ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ 0, /*tp_doc*/ 0, /*tp_traverse*/ 0, /*tp_clear*/ @@ -1397,6 +1410,16 @@ static PyTypeObject PySSLSocket_Type = { 0, /*tp_iter*/ 0, /*tp_iternext*/ PySSLMethods, /*tp_methods*/ + 0, /*tp_members*/ + 0, /*tp_getset*/ + 0, /*tp_base*/ + 0, /*tp_dict*/ + 0, /*tp_descr_get*/ + 0, /*tp_descr_set*/ + 0, /*tp_dictoffset*/ + 0, /*tp_init*/ + 0, /*tp_alloc*/ + PySSL_new /*tp_new*/ }; @@ -1656,21 +1679,6 @@ load_verify_locations(PySSLContext *self Py_RETURN_NONE; } -static PyObject * -context_wrap_socket(PySSLContext *self, PyObject *args, PyObject *kwds) -{ - char *kwlist[] = {"sock", "server_side", NULL}; - PySocketSockObject *sock; - int server_side = 0; - - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!i:_wrap_socket", kwlist, - PySocketModule.Sock_Type, - &sock, &server_side)) - return NULL; - - return (PyObject *) newPySSLSocket(self->ctx, sock, server_side); -} - static PyGetSetDef context_getsetlist[] = { {"options", (getter) get_options, (setter) set_options, NULL}, @@ -1680,8 +1688,6 @@ static PyGetSetDef context_getsetlist[] }; static struct PyMethodDef context_methods[] = { - {"_wrap_socket", (PyCFunction) context_wrap_socket, - METH_VARARGS | METH_KEYWORDS, NULL}, {"set_ciphers", (PyCFunction) set_ciphers, METH_VARARGS, NULL}, {"load_cert_chain", (PyCFunction) load_cert_chain,