diff -r 21f6c4378846 Lib/asyncio/selector_events.py --- a/Lib/asyncio/selector_events.py Mon Oct 26 04:37:55 2015 -0400 +++ b/Lib/asyncio/selector_events.py Mon Oct 26 16:12:20 2015 -0400 @@ -56,8 +56,8 @@ def _make_socket_transport(self, sock, protocol, waiter=None, *, extra=None, server=None): - return _SelectorSocketTransport(self, sock, protocol, waiter, - extra, server) + return _SelectorSocketStartTLSTransport(self, sock, protocol, waiter, + extra, server) def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, *, server_side=False, server_hostname=None, @@ -745,6 +745,11 @@ return True +class _SelectorSocketStartTLSTransport(_SelectorSocketTransport, + sslproto.TLSTransportMixin): + pass + + class _SelectorSslTransport(_SelectorTransport): _buffer_factory = bytearray diff -r 21f6c4378846 Lib/asyncio/sslproto.py --- a/Lib/asyncio/sslproto.py Mon Oct 26 04:37:55 2015 -0400 +++ b/Lib/asyncio/sslproto.py Mon Oct 26 16:12:20 2015 -0400 @@ -400,7 +400,8 @@ """ def __init__(self, loop, app_protocol, sslcontext, waiter, - server_side=False, server_hostname=None): + server_side=False, server_hostname=None, + call_connection_made=True): if ssl is None: raise RuntimeError('stdlib ssl module not available') @@ -433,6 +434,7 @@ self._in_shutdown = False # transport, ex: SelectorSocketTransport self._transport = None + self._call_connection_made = call_connection_made def _wakeup_waiter(self, exc=None): if self._waiter is None: @@ -596,7 +598,8 @@ compression=sslobj.compression(), ssl_object=sslobj, ) - self._app_protocol.connection_made(self._app_transport) + if self._call_connection_made: + self._app_protocol.connection_made(self._app_transport) self._wakeup_waiter() self._session_established = True # In case transport.write() was already called. Don't call @@ -675,3 +678,28 @@ self._transport.abort() finally: self._finalize() + + +class TLSTransportMixin(transports.TLSTransport): + + if _is_sslproto_available(): + + def start_tls(self, sslcontext, *, + server_side=False, + server_hostname=None, + waiter=None): + + loop = self._loop + app_protocol = self._protocol + + ssl_protocol = SSLProtocol(loop=self._loop, + app_protocol=app_protocol, + sslcontext=sslcontext, + waiter=waiter, + server_side=server_side, + server_hostname=server_hostname, + call_connection_made=False) + + self._protocol = ssl_protocol + ssl_protocol.connection_made(self) + return ssl_protocol._app_transport diff -r 21f6c4378846 Lib/asyncio/streams.py --- a/Lib/asyncio/streams.py Mon Oct 26 04:37:55 2015 -0400 +++ b/Lib/asyncio/streams.py Mon Oct 26 16:12:20 2015 -0400 @@ -14,6 +14,7 @@ from . import compat from . import events from . import futures +from . import sslproto from . import protocols from .coroutines import coroutine from .log import logger @@ -215,6 +216,7 @@ self._stream_reader = stream_reader self._stream_writer = None self._client_connected_cb = client_connected_cb + self._transport = None def connection_made(self, transport): self._stream_reader.set_transport(transport) @@ -226,6 +228,7 @@ self._stream_writer) if coroutines.iscoroutine(res): self._loop.create_task(res) + self._transport = transport def connection_lost(self, exc): if exc is None: @@ -241,6 +244,23 @@ self._stream_reader.feed_eof() return True + def _start_tls(self, sslcontext, *, + server_side=False, + server_hostname=None, + waiter=None): + + new_transport = self._transport.start_tls( + sslcontext, + server_side=server_side, + server_hostname=server_hostname, + waiter=waiter) + + self._stream_reader._transport = new_transport + if self._stream_writer is not None: + self._stream_writer._transport = new_transport + + return new_transport + class StreamWriter: """Wraps a Transport. @@ -312,6 +332,29 @@ yield yield from self._protocol._drain_helper() + @coroutine + def start_tls(self, sslcontext, *, + server_side=False, + server_hostname=None): + + if not sslproto._is_sslproto_available(): + # Python 3.5 or greater is required + raise NotImplementedError + + yield from self.drain() + + waiter = futures.Future(loop=self._loop) + + new_transport = self._protocol._start_tls( + sslcontext, + server_side=server_side, + server_hostname=server_hostname, + waiter=waiter) + + self._transport = new_transport + + yield from waiter + class StreamReader: diff -r 21f6c4378846 Lib/asyncio/transports.py --- a/Lib/asyncio/transports.py Mon Oct 26 04:37:55 2015 -0400 +++ b/Lib/asyncio/transports.py Mon Oct 26 16:12:20 2015 -0400 @@ -215,6 +215,11 @@ raise NotImplementedError +class TLSTransport(Transport): + def start_tls(self, sslcontetx, *, server_side=False, server_hostname=None): + raise NotImplementedError + + class _FlowControlMixin(Transport): """All the logic for (write) flow control in a mix-in base class.