diff -r 53e94a687570 Lib/asyncio/selector_events.py --- a/Lib/asyncio/selector_events.py Tue Jan 27 11:10:18 2015 -0500 +++ b/Lib/asyncio/selector_events.py Tue Jan 27 23:38:57 2015 +0100 @@ -22,6 +22,7 @@ from . import futures from . import selectors from . import transports from . import sslproto +from .coroutines import coroutine from .log import logger @@ -181,16 +182,41 @@ class BaseSelectorEventLoop(base_events. else: raise # The event loop will catch, log and ignore it. else: + accept = self._accept_connection2(protocol_factory, conn, addr, + sslcontext, server) + self.create_task(accept) + # It's now up to the protocol to handle the connection. + + @coroutine + def _accept_connection2(self, protocol_factory, conn, addr, + sslcontext=None, server=None): + try: protocol = protocol_factory() + waiter = futures.Future(loop=self) if sslcontext: - self._make_ssl_transport( - conn, protocol, sslcontext, + transport = self._make_ssl_transport( + conn, protocol, sslcontext, waiter=waiter, server_side=True, extra={'peername': addr}, server=server) else: - self._make_socket_transport( - conn, protocol , extra={'peername': addr}, + transport = self._make_socket_transport( + conn, protocol, waiter=waiter, extra={'peername': addr}, server=server) - # It's now up to the protocol to handle the connection. + + # Ensure that the transport has a _force_close() method + assert isinstance(transport, _SelectorTransport) + except Exception as exc: + self.call_exception_handler({ + 'message': ('Error on transport creation ' + 'for incoming connection'), + 'exception': exc, + }) + return + + try: + yield from waiter + except Exception as exc: + # _force_close() should call protocol.connection_lost(exc) + transport._force_close(exc) def add_reader(self, fd, callback, *args): """Add a reader callback.""" diff -r 53e94a687570 Lib/test/test_asyncio/test_events.py --- a/Lib/test/test_asyncio/test_events.py Tue Jan 27 11:10:18 2015 -0500 +++ b/Lib/test/test_asyncio/test_events.py Tue Jan 27 23:38:57 2015 +0100 @@ -87,7 +87,6 @@ class MyBaseProto(asyncio.Protocol): self.state = 'EOF' def connection_lost(self, exc): - assert self.state in ('CONNECTED', 'EOF'), self.state self.state = 'CLOSED' if self.done: self.done.set_result(None) diff -r 53e94a687570 Lib/test/test_asyncio/test_selector_events.py --- a/Lib/test/test_asyncio/test_selector_events.py Tue Jan 27 11:10:18 2015 -0500 +++ b/Lib/test/test_asyncio/test_selector_events.py Tue Jan 27 23:38:57 2015 +0100 @@ -10,6 +10,7 @@ except ImportError: ssl = None import asyncio +from asyncio import selector_events from asyncio import selectors from asyncio import test_utils from asyncio.selector_events import BaseSelectorEventLoop @@ -659,6 +660,88 @@ class BaseSelectorEventLoopTests(test_ut selectors.EVENT_WRITE)]) self.loop.remove_writer.assert_called_with(1) + def check_accept_connection_cancel(self, sslcontext): + # Test that _accept_connection() handles cancellation on the creation + # of the transport + ssock = mock.Mock() + ssock.accept.return_value = (mock.Mock(), mock.Mock()) + + def connection_lost(exc): + proto.connection_lost.exc = exc + self.loop.stop() + + proto = mock.Mock() + proto.connection_lost.exc = None + proto.connection_lost.side_effect = connection_lost + proto.connection_made = lambda transport: None + + # Hook _make_socket_transport() to cancel the waiter + if sslcontext: + def make_transport(rawsock, protocol, sslcontext, + waiter=None, **kw): + waiter.cancel() + return make_ssl_transport(rawsock, protocol, sslcontext, + waiter, **kw) + make_ssl_transport = self.loop._make_ssl_transport + patch = mock.patch.object(self.loop, '_make_ssl_transport', + side_effect=make_transport) + else: + def make_transport(sock, protocol, waiter=None, **kw): + waiter.cancel() + return make_socket_transport(sock, protocol, waiter, **kw) + make_socket_transport = self.loop._make_socket_transport + patch = mock.patch.object(self.loop, '_make_socket_transport', + side_effect=make_transport) + + self.loop.add_reader = mock.Mock() + self.loop.remove_reader = mock.Mock() + self.loop.remove_writer = mock.Mock() + + with patch: + self.loop._accept_connection(lambda: proto, ssock, sslcontext) + self.loop.run_forever() + + self.assertIsInstance(proto.connection_lost.exc, + asyncio.CancelledError) + + def test_accept_connection_cancel(self): + self.check_accept_connection_cancel(None) + + def test_accept_ssl_connection_cancel(self): + self.check_accept_connection_cancel(mock.Mock()) + + def check_accept_connection_error(self, sslcontext): + # Test that _accept_connection() handles errors on the creation + # of the transport + ssock = mock.Mock() + ssock.accept.return_value = (mock.Mock(), mock.Mock()) + exc = Exception("oops") + + def make_transport(*args, **kw): + self.loop.stop() + raise exc + + if sslcontext: + method = '_make_ssl_transport' + else: + method = '_make_socket_transport' + + self.loop.call_exception_handler = mock.Mock() + + with mock.patch.object(self.loop, method, side_effect=make_transport): + self.loop._accept_connection(asyncio.Protocol, ssock, sslcontext) + self.loop.run_forever() + + self.assertTrue(self.loop.call_exception_handler.called) + context = self.loop.call_exception_handler.call_args[0][0] + self.assertIs(context['exception'], exc) + + def test_accept_connection_error(self): + self.check_accept_connection_error(None) + + def test_accept_ssl_connection_error(self): + self.check_accept_connection_error(mock.Mock()) + class SelectorTransportTests(test_utils.TestCase):