diff -r 08c67d4f8e36 Lib/asyncio/protocols.py --- a/Lib/asyncio/protocols.py Thu Jan 29 00:55:46 2015 +0100 +++ b/Lib/asyncio/protocols.py Thu Jan 29 01:00:16 2015 +0100 @@ -30,6 +30,14 @@ class BaseProtocol: aborted or closed). """ + def connection_failed(self, transport, exc): + """Called when the connection to the transport failed. + + This method is called by a server for an incoming client connection + when the creation of the transport failed, if the SSL handshake failed + for example. + """ + def pause_writing(self): """Called when the transport's buffer goes over the high-water mark. diff -r 08c67d4f8e36 Lib/asyncio/selector_events.py --- a/Lib/asyncio/selector_events.py Thu Jan 29 00:55:46 2015 +0100 +++ b/Lib/asyncio/selector_events.py Thu Jan 29 01:00:16 2015 +0100 @@ -22,6 +22,7 @@ from . import futures from . import selectors from . import transports from . import sslproto +from .coroutines import coroutine from .log import logger @@ -181,16 +182,39 @@ class BaseSelectorEventLoop(base_events. else: raise # The event loop will catch, log and ignore it. else: + accept = self._accept_connection2(protocol_factory, conn, addr, + sslcontext, server) + self.create_task(accept) + # It's now up to the protocol to handle the connection. + + @coroutine + def _accept_connection2(self, protocol_factory, conn, addr, + sslcontext=None, server=None): + try: protocol = protocol_factory() + waiter = futures.Future(loop=self) if sslcontext: - self._make_ssl_transport( - conn, protocol, sslcontext, + transport = self._make_ssl_transport( + conn, protocol, sslcontext, waiter=waiter, server_side=True, extra={'peername': addr}, server=server) else: - self._make_socket_transport( - conn, protocol , extra={'peername': addr}, + transport = self._make_socket_transport( + conn, protocol, waiter=waiter, extra={'peername': addr}, server=server) - # It's now up to the protocol to handle the connection. + except Exception as exc: + self.call_exception_handler({ + 'message': ('Error on transport creation ' + 'for incoming connection'), + 'exception': exc, + }) + return + + try: + yield from waiter + except Exception as exc: + self.call_soon(protocol.connection_failed, transport, exc) + # only close the transport when connection_failed() has been called + self.call_soon(transport.close) def add_reader(self, fd, callback, *args): """Add a reader callback.""" diff -r 08c67d4f8e36 Lib/test/test_asyncio/test_selector_events.py --- a/Lib/test/test_asyncio/test_selector_events.py Thu Jan 29 00:55:46 2015 +0100 +++ b/Lib/test/test_asyncio/test_selector_events.py Thu Jan 29 01:00:16 2015 +0100 @@ -10,6 +10,7 @@ except ImportError: ssl = None import asyncio +from asyncio import selector_events from asyncio import selectors from asyncio import test_utils from asyncio.selector_events import BaseSelectorEventLoop @@ -661,6 +662,81 @@ class BaseSelectorEventLoopTests(test_ut selectors.EVENT_WRITE)]) self.loop.remove_writer.assert_called_with(1) + def check_accept_connection_cancel(self, sslcontext): + # Test that _accept_connection() handles cancellation on the creation + # of the transport + ssock = mock.Mock() + ssock.accept.return_value = (mock.Mock(), mock.Mock()) + transport = mock.Mock() + transport.close._is_coroutine = False + + def connection_failed(transport, exc): + proto.connection_failed.exc = exc + self.loop.call_soon(self.loop.stop) + + proto = mock.Mock() + proto.connection_failed.exc = None + proto.connection_failed.side_effect = connection_failed + proto.connection_failed._is_coroutine = False + + # Hook _make_socket_transport() to cancel the waiter + def make_transport(*args, **kw): + kw['waiter'].cancel() + return transport + + self.loop.add_reader = mock.Mock() + self.loop.remove_reader = mock.Mock() + self.loop.remove_writer = mock.Mock() + + if sslcontext: + method = '_make_ssl_transport' + else: + method = '_make_socket_transport' + with mock.patch.object(self.loop, method, side_effect=make_transport): + self.loop._accept_connection(lambda: proto, ssock, sslcontext) + self.loop.run_forever() + + self.assertIsInstance(proto.connection_failed.exc, + asyncio.CancelledError) + self.assertTrue(transport.close.called) + + def test_accept_connection_cancel(self): + self.check_accept_connection_cancel(None) + + def test_accept_ssl_connection_cancel(self): + self.check_accept_connection_cancel(mock.Mock()) + + def check_accept_connection_error(self, sslcontext): + # Test that _accept_connection() handles errors on the creation + # of the transport + ssock = mock.Mock() + ssock.accept.return_value = (mock.Mock(), mock.Mock()) + exc = Exception("oops") + + def make_transport(*args, **kw): + self.loop.stop() + raise exc + + self.loop.call_exception_handler = mock.Mock() + + if sslcontext: + method = '_make_ssl_transport' + else: + method = '_make_socket_transport' + with mock.patch.object(self.loop, method, side_effect=make_transport): + self.loop._accept_connection(asyncio.Protocol, ssock, sslcontext) + self.loop.run_forever() + + self.assertTrue(self.loop.call_exception_handler.called) + context = self.loop.call_exception_handler.call_args[0][0] + self.assertIs(context['exception'], exc) + + def test_accept_connection_error(self): + self.check_accept_connection_error(None) + + def test_accept_ssl_connection_error(self): + self.check_accept_connection_error(mock.Mock()) + class SelectorTransportTests(test_utils.TestCase):