diff -r 12b2efa10da1 Lib/ssl.py --- a/Lib/ssl.py Fri Apr 05 10:14:28 2013 +0300 +++ b/Lib/ssl.py Tue Apr 09 13:25:41 2013 +0900 @@ -91,6 +91,7 @@ from socket import getnameinfo as _getnameinfo import base64 # for DER-to-PEM translation import errno +import select # Disable weak or insecure ciphers by default # (OpenSSL's default setting is 'DEFAULT:!aNULL:!eNULL') @@ -123,6 +124,9 @@ if certfile and not keyfile: keyfile = certfile + + self._shutdown_required = False + # see if it's connected try: socket.getpeername(self) @@ -281,28 +285,46 @@ def unwrap(self): if self._sslobj: - s = self._sslobj.shutdown() + s = self._ssl_shutdown() self._sslobj = None return s else: raise ValueError("No SSL wrapper around " + str(self)) def shutdown(self, how): + self._ssl_shutdown() self._sslobj = None socket.shutdown(self, how) def close(self): if self._makefile_refs < 1: + self._ssl_shutdown() self._sslobj = None socket.close(self) else: self._makefile_refs -= 1 + def _ssl_shutdown(self): + if self._sslobj and self._shutdown_required: + while True: + try: + s = self._sslobj.shutdown() + self._shutdown_required = False + return s + except SSLError, e: + if e.args[0] == SSL_ERROR_WANT_READ: + select.select([self._sock], [], []) + elif e.args[0] == SSL_ERROR_WANT_WRITE: + select.select([], [self._sock], []) + else: + raise + def do_handshake(self): """Perform a TLS/SSL handshake.""" self._sslobj.do_handshake() + self._shutdown_required = True def _real_connect(self, addr, return_errno): # Here we assume that the socket is client-side, and not diff -r 12b2efa10da1 Lib/test/test_ssl.py --- a/Lib/test/test_ssl.py Fri Apr 05 10:14:28 2013 +0300 +++ b/Lib/test/test_ssl.py Tue Apr 09 13:25:41 2013 +0900 @@ -17,6 +17,7 @@ import functools import platform +from SocketServer import TCPServer,StreamRequestHandler from BaseHTTPServer import HTTPServer from SimpleHTTPServer import SimpleHTTPRequestHandler @@ -794,6 +795,41 @@ def stop(self): self.server.shutdown() + class ThreadedTCPStreamServer(threading.Thread): + class TCPStreamHandler(StreamRequestHandler): + def handle(self): + self.wfile.write("123") + + def __init__(self, certfile, ssl_version): + self.flag = None + self.server = TCPServer((HOST, 0), self.TCPStreamHandler) + self.server.socket = ssl.wrap_socket( + self.server.socket, certfile=certfile, server_side=True, + ssl_version=ssl_version, suppress_ragged_eofs=False) + self.port = self.server.server_address[1] + threading.Thread.__init__(self) + self.daemon = True + + def __enter__(self): + self.start(threading.Event()) + self.flag.wait() + return self + + def __exit__(self, *args): + self.stop() + self.join() + + def start(self, flag=None): + self.flag = flag + threading.Thread.start(self) + + def run(self): + if self.flag: + self.flag.set() + self.server.serve_forever() + + def stop(self): + self.server.shutdown() def bad_cert_test(certfile): """ @@ -1358,6 +1394,21 @@ sock.close() self.assertIn("no shared cipher", str(server.conn_errors[0])) + def test_ssl_clean_shutdown(self): + for ssl_proto_ver in (ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1): + with ThreadedTCPStreamServer(CERTFILE, ssl_proto_ver) as server: + sock = socket.socket() + try: + s = ssl.wrap_socket(sock, ssl_version=ssl_proto_ver, suppress_ragged_eofs=False) + s.connect((HOST, server.port)) + s.write("a") + s.recv(3) + try: + self.assertEqual(s.recv(1), '') + except ssl.SSLError, e: + self.assertEqual(e.args[0], ssl.SSL_ERROR_ZERO_RETURN) + finally: + sock.close() def test_main(verbose=False): global CERTFILE, SVN_PYTHON_ORG_ROOT_CERT, NOKIACERT diff -r 12b2efa10da1 Modules/_ssl.c --- a/Modules/_ssl.c Fri Apr 05 10:14:28 2013 +0300 +++ b/Modules/_ssl.c Tue Apr 09 13:25:41 2013 +0900 @@ -1423,6 +1423,10 @@ sockstate = check_socket_and_wait_for_timeout(self->Socket, 0); else if (ssl_err == SSL_ERROR_WANT_WRITE) sockstate = check_socket_and_wait_for_timeout(self->Socket, 1); + else if (ssl_err == SSL_ERROR_SYSCALL && self->shutdown_seen_zero){ + err = 0; /* handle misleading error as described in SSL_shutdown man */ + break; + } else break; if (sockstate == SOCKET_HAS_TIMED_OUT) {