diff -r 53343d095f78 Lib/test/test_ssl.py --- a/Lib/test/test_ssl.py Wed Mar 23 00:54:48 2016 +0100 +++ b/Lib/test/test_ssl.py Wed Mar 23 01:45:32 2016 +0100 @@ -570,14 +570,14 @@ class BasicSocketTests(unittest.TestCase def test_unknown_channel_binding(self): # should raise ValueError for unknown type s = socket.socket(socket.AF_INET) - s.bind(('127.0.0.1', 0)) - s.listen() - c = socket.socket(socket.AF_INET) - c.connect(s.getsockname()) - with ssl.wrap_socket(c, do_handshake_on_connect=False) as ss: - with self.assertRaises(ValueError): - ss.get_channel_binding("unknown-type") - s.close() + with s: + s.bind(('127.0.0.1', 0)) + s.listen() + c = socket.socket(socket.AF_INET) + c.connect(s.getsockname()) + with ssl.wrap_socket(c, do_handshake_on_connect=False) as ss: + with self.assertRaises(ValueError): + ss.get_channel_binding("unknown-type") @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES, "'tls-unique' channel binding not available") @@ -1374,28 +1374,24 @@ class NetworkedTests(unittest.TestCase): with support.transient_internet(REMOTE_HOST): s = ssl.wrap_socket(socket.socket(socket.AF_INET), cert_reqs=ssl.CERT_NONE) - try: + with s: s.connect((REMOTE_HOST, 443)) self.assertEqual({}, s.getpeercert()) - finally: - s.close() # this should fail because we have no verification certs s = ssl.wrap_socket(socket.socket(socket.AF_INET), cert_reqs=ssl.CERT_REQUIRED) - self.assertRaisesRegex(ssl.SSLError, "certificate verify failed", - s.connect, (REMOTE_HOST, 443)) - s.close() + with s: + self.assertRaisesRegex(ssl.SSLError, "certificate verify failed", + s.connect, (REMOTE_HOST, 443)) # this should succeed because we specify the root cert s = ssl.wrap_socket(socket.socket(socket.AF_INET), cert_reqs=ssl.CERT_REQUIRED, ca_certs=REMOTE_ROOT_CERT) - try: + with s: s.connect((REMOTE_HOST, 443)) self.assertTrue(s.getpeercert()) - finally: - s.close() def test_connect_ex(self): # Issue #11326: check connect_ex() implementation @@ -1403,11 +1399,9 @@ class NetworkedTests(unittest.TestCase): s = ssl.wrap_socket(socket.socket(socket.AF_INET), cert_reqs=ssl.CERT_REQUIRED, ca_certs=REMOTE_ROOT_CERT) - try: + with s: self.assertEqual(0, s.connect_ex((REMOTE_HOST, 443))) self.assertTrue(s.getpeercert()) - finally: - s.close() def test_non_blocking_connect_ex(self): # Issue #11326: non-blocking connect_ex() should allow handshake @@ -1417,7 +1411,7 @@ class NetworkedTests(unittest.TestCase): cert_reqs=ssl.CERT_REQUIRED, ca_certs=REMOTE_ROOT_CERT, do_handshake_on_connect=False) - try: + with s: s.setblocking(False) rc = s.connect_ex((REMOTE_HOST, 443)) # EWOULDBLOCK under Windows, EINPROGRESS elsewhere @@ -1435,8 +1429,6 @@ class NetworkedTests(unittest.TestCase): select.select([], [s], [], 5.0) # SSL established self.assertTrue(s.getpeercert()) - finally: - s.close() def test_timeout_connect_ex(self): # Issue #12065: on a timeout, connect_ex() should return the original @@ -1446,21 +1438,19 @@ class NetworkedTests(unittest.TestCase): cert_reqs=ssl.CERT_REQUIRED, ca_certs=REMOTE_ROOT_CERT, do_handshake_on_connect=False) - try: + with s: s.settimeout(0.0000001) rc = s.connect_ex((REMOTE_HOST, 443)) if rc == 0: self.skipTest("REMOTE_HOST responded too quickly") self.assertIn(rc, (errno.EAGAIN, errno.EWOULDBLOCK)) - finally: - s.close() def test_connect_ex_error(self): with support.transient_internet(REMOTE_HOST): s = ssl.wrap_socket(socket.socket(socket.AF_INET), cert_reqs=ssl.CERT_REQUIRED, ca_certs=REMOTE_ROOT_CERT) - try: + with s: rc = s.connect_ex((REMOTE_HOST, 444)) # Issue #19919: Windows machines or VMs hosted on Windows # machines sometimes return EWOULDBLOCK. @@ -1469,39 +1459,36 @@ class NetworkedTests(unittest.TestCase): errno.EWOULDBLOCK, ) self.assertIn(rc, errors) - finally: - s.close() def test_connect_with_context(self): with support.transient_internet(REMOTE_HOST): # Same as test_connect, but with a separately created context ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) s = ctx.wrap_socket(socket.socket(socket.AF_INET)) - s.connect((REMOTE_HOST, 443)) - try: + with s: + s.connect((REMOTE_HOST, 443)) self.assertEqual({}, s.getpeercert()) - finally: - s.close() + # Same with a server hostname s = ctx.wrap_socket(socket.socket(socket.AF_INET), server_hostname=REMOTE_HOST) - s.connect((REMOTE_HOST, 443)) - s.close() + with s: + s.connect((REMOTE_HOST, 443)) + # This should fail because we have no verification certs ctx.verify_mode = ssl.CERT_REQUIRED s = ctx.wrap_socket(socket.socket(socket.AF_INET)) - self.assertRaisesRegex(ssl.SSLError, "certificate verify failed", - s.connect, (REMOTE_HOST, 443)) - s.close() + with s: + self.assertRaisesRegex(ssl.SSLError, "certificate verify failed", + s.connect, (REMOTE_HOST, 443)) + # This should succeed because we specify the root cert ctx.load_verify_locations(REMOTE_ROOT_CERT) s = ctx.wrap_socket(socket.socket(socket.AF_INET)) - s.connect((REMOTE_HOST, 443)) - try: + with s: + s.connect((REMOTE_HOST, 443)) cert = s.getpeercert() self.assertTrue(cert) - finally: - s.close() def test_connect_capath(self): # Verify server certificates using the `capath` argument @@ -1514,23 +1501,20 @@ class NetworkedTests(unittest.TestCase): ctx.verify_mode = ssl.CERT_REQUIRED ctx.load_verify_locations(capath=CAPATH) s = ctx.wrap_socket(socket.socket(socket.AF_INET)) - s.connect((REMOTE_HOST, 443)) - try: + with s: + s.connect((REMOTE_HOST, 443)) cert = s.getpeercert() self.assertTrue(cert) - finally: - s.close() + # Same with a bytes `capath` argument ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) ctx.verify_mode = ssl.CERT_REQUIRED ctx.load_verify_locations(capath=BYTES_CAPATH) s = ctx.wrap_socket(socket.socket(socket.AF_INET)) - s.connect((REMOTE_HOST, 443)) - try: + with s: + s.connect((REMOTE_HOST, 443)) cert = s.getpeercert() self.assertTrue(cert) - finally: - s.close() def test_connect_cadata(self): with open(REMOTE_ROOT_CERT) as f: @@ -1561,14 +1545,15 @@ class NetworkedTests(unittest.TestCase): # file descriptor, hence skipping the test under Windows). with support.transient_internet(REMOTE_HOST): ss = ssl.wrap_socket(socket.socket(socket.AF_INET)) - ss.connect((REMOTE_HOST, 443)) - fd = ss.fileno() - f = ss.makefile() - f.close() - # The fd is still open - os.read(fd, 0) - # Closing the SSL socket should close the fd too - ss.close() + with ss: + ss.connect((REMOTE_HOST, 443)) + fd = ss.fileno() + f = ss.makefile() + f.close() + # The fd is still open + os.read(fd, 0) + # Closing the SSL socket should close the fd too + ss.close() gc.collect() with self.assertRaises(OSError) as e: os.read(fd, 0) @@ -1582,17 +1567,17 @@ class NetworkedTests(unittest.TestCase): s = ssl.wrap_socket(s, cert_reqs=ssl.CERT_NONE, do_handshake_on_connect=False) - count = 0 - while True: - try: - count += 1 - s.do_handshake() - break - except ssl.SSLWantReadError: - select.select([s], [], []) - except ssl.SSLWantWriteError: - select.select([], [s], []) - s.close() + with s: + count = 0 + while True: + try: + count += 1 + s.do_handshake() + break + except ssl.SSLWantReadError: + select.select([s], [], []) + except ssl.SSLWantWriteError: + select.select([], [s], []) if support.verbose: sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count) @@ -1658,15 +1643,13 @@ class NetworkedTests(unittest.TestCase): ctx.load_verify_locations(sha256_cert) s = ctx.wrap_socket(socket.socket(socket.AF_INET), server_hostname="sha256.tbs-internet.com") - try: + with s: s.connect(remote) if support.verbose: sys.stdout.write("\nCipher with %r is %r\n" % (remote, s.cipher())) sys.stdout.write("Certificate is:\n%s\n" % pprint.pformat(s.getpeercert())) - finally: - s.close() def test_get_ca_certs_capath(self): # capath certs are loaded on request @@ -1676,12 +1659,10 @@ class NetworkedTests(unittest.TestCase): ctx.load_verify_locations(capath=CAPATH) self.assertEqual(ctx.get_ca_certs(), []) s = ctx.wrap_socket(socket.socket(socket.AF_INET)) - s.connect((REMOTE_HOST, 443)) - try: + with s: + s.connect((REMOTE_HOST, 443)) cert = s.getpeercert() self.assertTrue(cert) - finally: - s.close() self.assertEqual(len(ctx.get_ca_certs()), 1) @needs_sni @@ -1739,52 +1720,53 @@ class NetworkedBIOTests(unittest.TestCas def test_handshake(self): with support.transient_internet(REMOTE_HOST): sock = socket.socket(socket.AF_INET) - sock.connect((REMOTE_HOST, 443)) - incoming = ssl.MemoryBIO() - outgoing = ssl.MemoryBIO() - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - ctx.verify_mode = ssl.CERT_REQUIRED - ctx.load_verify_locations(REMOTE_ROOT_CERT) - ctx.check_hostname = True - sslobj = ctx.wrap_bio(incoming, outgoing, False, REMOTE_HOST) - self.assertIs(sslobj._sslobj.owner, sslobj) - self.assertIsNone(sslobj.cipher()) - self.assertIsNone(sslobj.shared_ciphers()) - self.assertRaises(ValueError, sslobj.getpeercert) - if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES: - self.assertIsNone(sslobj.get_channel_binding('tls-unique')) - self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake) - self.assertTrue(sslobj.cipher()) - self.assertIsNone(sslobj.shared_ciphers()) - self.assertTrue(sslobj.getpeercert()) - if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES: - self.assertTrue(sslobj.get_channel_binding('tls-unique')) - try: - self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap) - except ssl.SSLSyscallError: - # self-signed.pythontest.net probably shuts down the TCP - # connection without sending a secure shutdown message, and - # this is reported as SSL_ERROR_SYSCALL - pass - self.assertRaises(ssl.SSLError, sslobj.write, b'foo') - sock.close() + with sock: + sock.connect((REMOTE_HOST, 443)) + incoming = ssl.MemoryBIO() + outgoing = ssl.MemoryBIO() + ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx.verify_mode = ssl.CERT_REQUIRED + ctx.load_verify_locations(REMOTE_ROOT_CERT) + ctx.check_hostname = True + sslobj = ctx.wrap_bio(incoming, outgoing, False, REMOTE_HOST) + self.assertIs(sslobj._sslobj.owner, sslobj) + self.assertIsNone(sslobj.cipher()) + self.assertIsNone(sslobj.shared_ciphers()) + self.assertRaises(ValueError, sslobj.getpeercert) + if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES: + self.assertIsNone(sslobj.get_channel_binding('tls-unique')) + self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake) + self.assertTrue(sslobj.cipher()) + self.assertIsNone(sslobj.shared_ciphers()) + self.assertTrue(sslobj.getpeercert()) + if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES: + self.assertTrue(sslobj.get_channel_binding('tls-unique')) + try: + self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap) + except ssl.SSLSyscallError: + # self-signed.pythontest.net probably shuts down the TCP + # connection without sending a secure shutdown message, and + # this is reported as SSL_ERROR_SYSCALL + pass + self.assertRaises(ssl.SSLError, sslobj.write, b'foo') def test_read_write_data(self): with support.transient_internet(REMOTE_HOST): sock = socket.socket(socket.AF_INET) - sock.connect((REMOTE_HOST, 443)) - incoming = ssl.MemoryBIO() - outgoing = ssl.MemoryBIO() - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - ctx.verify_mode = ssl.CERT_NONE - sslobj = ctx.wrap_bio(incoming, outgoing, False) - self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake) - req = b'GET / HTTP/1.0\r\n\r\n' - self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req) - buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024) - self.assertEqual(buf[:5], b'HTTP/') - self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap) - sock.close() + with sock: + sock.connect((REMOTE_HOST, 443)) + incoming = ssl.MemoryBIO() + outgoing = ssl.MemoryBIO() + ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx.verify_mode = ssl.CERT_NONE + sslobj = ctx.wrap_bio(incoming, outgoing, False) + self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake) + req = b'GET / HTTP/1.0\r\n\r\n' + self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req) + buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, + 1024) + self.assertEqual(buf[:5], b'HTTP/') + self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap) try: @@ -2247,32 +2229,33 @@ else: with server: s = context.wrap_socket(socket.socket(), do_handshake_on_connect=False) - s.connect((HOST, server.port)) - # getpeercert() raise ValueError while the handshake isn't - # done. - with self.assertRaises(ValueError): - s.getpeercert() - s.do_handshake() - cert = s.getpeercert() - self.assertTrue(cert, "Can't get peer certificate.") - cipher = s.cipher() - if support.verbose: - sys.stdout.write(pprint.pformat(cert) + '\n') - sys.stdout.write("Connection cipher is " + str(cipher) + '.\n') - if 'subject' not in cert: - self.fail("No subject field in certificate: %s." % - pprint.pformat(cert)) - if ((('organizationName', 'Python Software Foundation'),) - not in cert['subject']): - self.fail( - "Missing or invalid 'organizationName' field in certificate subject; " - "should be 'Python Software Foundation'.") - self.assertIn('notBefore', cert) - self.assertIn('notAfter', cert) - before = ssl.cert_time_to_seconds(cert['notBefore']) - after = ssl.cert_time_to_seconds(cert['notAfter']) - self.assertLess(before, after) - s.close() + with s: + s.connect((HOST, server.port)) + # getpeercert() raise ValueError while the handshake isn't + # done. + with self.assertRaises(ValueError): + s.getpeercert() + s.do_handshake() + cert = s.getpeercert() + self.assertTrue(cert, "Can't get peer certificate.") + cipher = s.cipher() + if support.verbose: + sys.stdout.write(pprint.pformat(cert) + '\n') + sys.stdout.write("Connection cipher is " + str(cipher) + '.\n') + if 'subject' not in cert: + self.fail("No subject field in certificate: %s." % + pprint.pformat(cert)) + if ((('organizationName', 'Python Software Foundation'),) + not in cert['subject']): + self.fail( + "Missing or invalid 'organizationName' field in certificate subject; " + "should be 'Python Software Foundation'.") + self.assertIn('notBefore', cert) + self.assertIn('notAfter', cert) + before = ssl.cert_time_to_seconds(cert['notBefore']) + after = ssl.cert_time_to_seconds(cert['notAfter']) + self.assertLess(before, after) + s.close() @unittest.skipUnless(have_verify_flags(), "verify_flags need OpenSSL > 0.9.8") @@ -2645,7 +2628,7 @@ else: server.port, os.path.split(CERTFILE)[1]) context = ssl.create_default_context(cafile=CERTFILE) f = urllib.request.urlopen(url, context=context) - try: + with f: dlen = f.info().get("content-length") if dlen and (int(dlen) > 0): d2 = f.read(int(dlen)) @@ -2653,8 +2636,6 @@ else: sys.stdout.write( " client: read %d bytes from remote server '%s'\n" % (len(d2), server)) - finally: - f.close() self.assertEqual(d1, d2) def test_asyncore_server(self): @@ -2668,23 +2649,24 @@ else: server = AsyncoreEchoServer(CERTFILE) with server: s = ssl.wrap_socket(socket.socket()) - s.connect(('127.0.0.1', server.port)) - if support.verbose: - sys.stdout.write( - " client: sending %r...\n" % indata) - s.write(indata) - outdata = s.read() - if support.verbose: - sys.stdout.write(" client: read %r\n" % outdata) - if outdata != indata.lower(): - self.fail( - "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n" - % (outdata[:20], len(outdata), - indata[:20].lower(), len(indata))) - s.write(b"over\n") - if support.verbose: - sys.stdout.write(" client: closing connection.\n") - s.close() + with s: + s.connect(('127.0.0.1', server.port)) + if support.verbose: + sys.stdout.write( + " client: sending %r...\n" % indata) + s.write(indata) + outdata = s.read() + if support.verbose: + sys.stdout.write(" client: read %r\n" % outdata) + if outdata != indata.lower(): + self.fail( + "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n" + % (outdata[:20], len(outdata), + indata[:20].lower(), len(indata))) + s.write(b"over\n") + if support.verbose: + sys.stdout.write(" client: closing connection.\n") + s.close() if support.verbose: sys.stdout.write(" client: connection closed.\n") @@ -2817,21 +2799,22 @@ else: ca_certs=CERTFILE, cert_reqs=ssl.CERT_NONE, ssl_version=ssl.PROTOCOL_TLSv1) - s.connect((HOST, server.port)) - s.setblocking(False) - - # If we keep sending data, at some point the buffers - # will be full and the call will block - buf = bytearray(8192) - def fill_buffer(): - while True: - s.send(buf) - self.assertRaises((ssl.SSLWantWriteError, - ssl.SSLWantReadError), fill_buffer) - - # Now read all the output and discard it - s.setblocking(True) - s.close() + with s: + s.connect((HOST, server.port)) + s.setblocking(False) + + # If we keep sending data, at some point the buffers + # will be full and the call will block + buf = bytearray(8192) + def fill_buffer(): + while True: + s.send(buf) + self.assertRaises((ssl.SSLWantWriteError, + ssl.SSLWantReadError), fill_buffer) + + # Now read all the output and discard it + s.setblocking(True) + s.close() def test_handshake_timeout(self): # Issue #5103: SSL handshake must respect the socket timeout @@ -2859,24 +2842,21 @@ else: started.wait() try: - try: - c = socket.socket(socket.AF_INET) + c = socket.socket(socket.AF_INET) + with c: c.settimeout(0.2) c.connect((host, port)) # Will attempt handshake and time out self.assertRaisesRegex(socket.timeout, "timed out", ssl.wrap_socket, c) - finally: - c.close() - try: - c = socket.socket(socket.AF_INET) - c = ssl.wrap_socket(c) + + c = socket.socket(socket.AF_INET) + c = ssl.wrap_socket(c) + with c: c.settimeout(0.2) # Will attempt handshake and time out self.assertRaisesRegex(socket.timeout, "timed out", c.connect, (host, port)) - finally: - c.close() finally: finish = True t.join() @@ -2910,9 +2890,9 @@ else: # Client wait until server setup and perform a connect. evt.wait() client = context.wrap_socket(socket.socket()) - client.connect((host, port)) - client_addr = client.getsockname() - client.close() + with client: + client.connect((host, port)) + client_addr = client.getsockname() t.join() remote.close() server.close()