diff -r 6f8a5458108d Lib/asyncio/selector_events.py --- a/Lib/asyncio/selector_events.py Thu Dec 05 16:13:03 2013 +0100 +++ b/Lib/asyncio/selector_events.py Thu Dec 05 17:07:45 2013 +0100 @@ -583,7 +583,8 @@ # cadefault=True. if hasattr(ssl, '_create_stdlib_context'): sslcontext = ssl._create_stdlib_context( - cert_reqs=ssl.CERT_REQUIRED) + cert_reqs=ssl.CERT_REQUIRED, + check_hostname=bool(server_hostname)) else: # Fallback for Python 3.3. sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) @@ -639,17 +640,19 @@ self._loop.remove_reader(self._sock_fd) self._loop.remove_writer(self._sock_fd) - # Verify hostname if requested. peercert = self._sock.getpeercert() - if (self._server_hostname and - self._sslcontext.verify_mode != ssl.CERT_NONE): - try: - ssl.match_hostname(peercert, self._server_hostname) - except Exception as exc: - self._sock.close() - if self._waiter is not None: - self._waiter.set_exception(exc) - return + if not hasattr(self._sslcontext, 'check_hostname'): + # Verify hostname if requested, Python 3.4+ uses check_hostname + # and checks the hostname in do_handshake() + if (self._server_hostname and + self._sslcontext.verify_mode != ssl.CERT_NONE): + try: + ssl.match_hostname(peercert, self._server_hostname) + except Exception as exc: + self._sock.close() + if self._waiter is not None: + self._waiter.set_exception(exc) + return # Add extra info that becomes available after handshake. self._extra.update(peercert=peercert, diff -r 6f8a5458108d Lib/test/test_asyncio/test_events.py --- a/Lib/test/test_asyncio/test_events.py Thu Dec 05 16:13:03 2013 +0100 +++ b/Lib/test/test_asyncio/test_events.py Thu Dec 05 17:07:45 2013 +0100 @@ -17,7 +17,7 @@ import errno import unittest import unittest.mock -from test.support import find_unused_port, IPV6_ENABLED +from test.support import find_unused_port, IPV6_ENABLED, TEST_HOME_DIR from asyncio import futures @@ -30,10 +30,26 @@ from asyncio import locks +def data_file(filename): + fullname = os.path.join(TEST_HOME_DIR, filename) + if os.path.isfile(fullname): + return fullname + fullname = os.path.join(os.path.dirname(__file__), filename) + if os.path.isfile(fullname): + return fullname + raise FileNotFoundError(filename) + +ONLYCERT = data_file('ssl_cert.pem') +ONLYKEY = data_file('ssl_key.pem') +SIGNED_CERTFILE = data_file('keycert3.pem') +SIGNING_CA = data_file('pycacert.pem') + + class MyProto(protocols.Protocol): done = None def __init__(self, loop=None): + self.transport = None self.state = 'INITIAL' self.nbytes = 0 if loop is not None: @@ -587,6 +603,20 @@ # close server server.close() + def _make_ssl_server(self, factory, certfile, keyfile=None): + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.load_cert_chain(certfile, keyfile) + + f = self.loop.create_server( + factory, '127.0.0.1', 0, ssl=sslcontext) + + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + host, port = sock.getsockname() + self.assertEqual(host, '127.0.0.1') + return server, host, port + @unittest.skipIf(ssl is None, 'No ssl module') def test_create_server_ssl(self): proto = None @@ -602,19 +632,7 @@ proto = MyProto(loop=self.loop) return proto - here = os.path.dirname(__file__) - sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - sslcontext.load_cert_chain( - certfile=os.path.join(here, 'sample.crt'), - keyfile=os.path.join(here, 'sample.key')) - - f = self.loop.create_server( - factory, '127.0.0.1', 0, ssl=sslcontext) - - server = self.loop.run_until_complete(f) - sock = server.sockets[0] - host, port = sock.getsockname() - self.assertEqual(host, '127.0.0.1') + server, host, port = self._make_ssl_server(factory, ONLYCERT, ONLYKEY) f_c = self.loop.create_connection(ClientMyProto, host, port, ssl=test_utils.dummy_ssl_context()) @@ -646,6 +664,93 @@ # stop serving server.close() + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_server_ssl_verify_failed(self): + proto = None + + def factory(): + nonlocal proto + proto = MyProto(loop=self.loop) + return proto + + server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE) + + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + if sys.version_info >= (3, 4): + sslcontext_client.check_hostname = True + + # no CA loaded + f_c = self.loop.create_connection(MyProto, host, port, + ssl=sslcontext_client) + with self.assertRaisesRegex(ssl.SSLError, + 'certificate verify failed '): + self.loop.run_until_complete(f_c) + + # close connection + self.assertIsNone(proto.transport) + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_server_ssl_match_failed(self): + proto = None + + def factory(): + nonlocal proto + proto = MyProto(loop=self.loop) + return proto + + server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE) + + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + sslcontext_client.load_verify_locations( + cafile=SIGNING_CA) + if sys.version_info >= (3, 4): + sslcontext_client.check_hostname = True + + # incorrect server_hostname + f_c = self.loop.create_connection(MyProto, host, port, + ssl=sslcontext_client) + with self.assertRaisesRegex(ssl.CertificateError, + "hostname '127.0.0.1' doesn't match 'localhost'"): + self.loop.run_until_complete(f_c) + + # close connection + proto.transport.close() + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_server_ssl_verified(self): + proto = None + + def factory(): + nonlocal proto + proto = MyProto(loop=self.loop) + return proto + + server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE) + + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + sslcontext_client.load_verify_locations(cafile=SIGNING_CA) + if sys.version_info >= (3, 4): + sslcontext_client.check_hostname = True + + # Connection succeeds with correct CA and server hostname. + f_c = self.loop.create_connection(MyProto, host, port, + ssl=sslcontext_client, + server_hostname='localhost') + client, pr = self.loop.run_until_complete(f_c) + + # close connection + proto.transport.close() + client.close() + server.close() + def test_create_server_sock(self): proto = futures.Future(loop=self.loop)