diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -21,6 +21,7 @@ from . import events from . import futures from . import selectors from . import transports +from . import sslproto from .log import logger @@ -58,9 +59,18 @@ class BaseSelectorEventLoop(base_events. def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, server_side=False, server_hostname=None, extra=None, server=None): - return _SelectorSslTransport( - self, rawsock, protocol, sslcontext, waiter, - server_side, server_hostname, extra, server) + ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter, + server_side, server_hostname) + _SelectorSocketTransport(self, rawsock, ssl_protocol, + extra=extra, server=server) + return ssl_protocol._app_transport + + #def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + #server_side=False, server_hostname=None, + #extra=None, server=None): + #return _SelectorSslTransport( + #self, rawsock, protocol, sslcontext, waiter, + #server_side, server_hostname, extra, server) def _make_datagram_transport(self, sock, protocol, address=None, waiter=None, extra=None): diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py new file mode 100644 --- /dev/null +++ b/Lib/asyncio/sslproto.py @@ -0,0 +1,590 @@ + +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import transports +from .log import logger + + +class SSLPipe(object): + """An SSL "Pipe". + + An SSL pipe allows you to communicate with an SSL/TLS protocol instance + through memory buffers. It can be used to implement a security layer for an + existing connection where you don't have access to the connection's file + descriptor, or for some reason you don't want to use it. + + An SSL pipe can be in "wrapped" and "unwrapped" mode. In unwrapped mode, + data is passed through untransformed. In wrapped mode, application level + data is encrypted to SSL record level data and vice versa. The SSL record + level is the lowest level in the SSL protocol suite and is what travels + as-is over the wire. + + An SslPipe initially is in "unwrapped" mode. To start SSL, call + :meth:`do_handshake`. To shutdown SSL again, call :meth:`unwrap`. + """ + + bufsize = 65536 + + # This previously used a socketpair to communicate with the SSL protocol + # instance but since October 2014 we're using a Memory BIO! This is + # cleaner, and more reliable on Windows. See for example issue #12 for more + # details. + + S_UNWRAPPED, S_DO_HANDSHAKE, S_WRAPPED, S_SHUTDOWN = range(4) + + def __init__(self, context, server_side, server_hostname=None): + """ + The *context* argument specifies the :class:`ssl.SSLContext` to use. + It is recommended to use :func:`~gruvi.ssl.create_ssl_context` so that + it will work on all supported Python versions. + + The *server_side* argument indicates whether this is a server side or + client side transport. + + The optional *server_hostname* argument can be used to specify the + hostname you are connecting to. You may only specify this parameter if + the _ssl module supports Server Name Indication (SNI). + """ + self._context = context + self._server_side = server_side + self._server_hostname = server_hostname + self._state = self.S_UNWRAPPED + self._incoming = ssl.MemoryBIO() + self._outgoing = ssl.MemoryBIO() + self._sslobj = None + self._need_ssldata = False + + @property + def context(self): + """The SSL context passed to the constructor.""" + return self._context + + @property + def ssl_object(self): + """The internal :class:`ssl.SSLObject` instance.""" + return self._sslobj + + @property + def need_ssldata(self): + """Whether more record level data is needed to complete a handshake + that is currently in progress.""" + return self._need_ssldata + + @property + def wrapped(self): + """Whether a security layer is currently in effect.""" + return self._state == self.S_WRAPPED + + def do_handshake(self, callback=None): + """Start the SSL handshake. Return a list of ssldata. + + The optional *callback* argument can be used to install a callback that + will be called when the handshake is complete. The callback will be + called with None if successful, else an exception instance. + """ + if self._state != self.S_UNWRAPPED: + raise RuntimeError('handshake in progress or completed') + wrapargs = () + self._sslobj = self._context.wrap_bio( + self._incoming, self._outgoing, + server_side=self._server_side, + server_hostname=self._server_hostname) + self._state = self.S_DO_HANDSHAKE + self._on_handshake_complete = callback + ssldata, appdata = self.feed_ssldata(b'', only_handshake=True) + assert len(appdata) == 0 + return ssldata + + def shutdown(self, callback=None): + """Start the SSL shutdown sequence. Return a list of ssldata. + + The optional *callback* argument can be used to install a callback that + will be called when the shutdown is complete. The callback will be + called without arguments. + """ + if self._state == self.S_UNWRAPPED: + raise RuntimeError('no security layer present') + self._state = self.S_SHUTDOWN + self._on_handshake_complete = callback + ssldata, appdata = self.feed_ssldata(b'') + assert appdata == [] or appdata == [b''] + return ssldata + + def feed_eof(self): + """Send a potentially "ragged" EOF. + + This method will raise an SSL_ERROR_EOF exception if the EOF is + unexpected. + """ + self._incoming.write_eof() + ssldata, appdata = self.feed_ssldata(b'') + assert appdata == [] or appdata == [b''] + + def feed_ssldata(self, data, only_handshake=False): + """Feed SSL record level data into the pipe. + + The data must be a bytes instance. It is OK to send an empty bytes + instance. This can be used to get ssldata for a handshake initiated by + this endpoint. + + Return a (ssldata, appdata) tuple. The ssldata element is a list of + buffers containing SSL data that needs to be sent to the remote SSL. + + The appdata element is a list of buffers containing plaintext data that + needs to be forwarded to the application. The appdata list may contain + an empty buffer indicating an SSL "close_notify" alert. This alert must + be acknowledged by calling :meth:`shutdown`. + """ + if self._state == self.S_UNWRAPPED: + # If unwrapped, pass plaintext data straight through. + return ([], [data] if data else []) + ssldata = []; appdata = [] + self._need_ssldata = False + if data: + self._incoming.write(data) + try: + if self._state == self.S_DO_HANDSHAKE: + # Call do_handshake() until it doesn't raise anymore. + self._sslobj.do_handshake() + self._state = self.S_WRAPPED + if self._on_handshake_complete: + self._on_handshake_complete(None) + if only_handshake: + return (ssldata, appdata) + if self._state == self.S_WRAPPED: + # Main state: read data from SSL until close_notify + while True: + chunk = self._sslobj.read(self.bufsize) + appdata.append(chunk) + if not chunk: # close_notify + break + if self._state == self.S_SHUTDOWN: + # Call shutdown() until it doesn't raise anymore. + self._sslobj.unwrap() + self._sslobj = None + self._state = self.S_UNWRAPPED + if self._on_handshake_complete: + self._on_handshake_complete() + if self._state == self.S_UNWRAPPED: + # Drain possible plaintext data after close_notify. + appdata.append(self._incoming.read()) + except (ssl.SSLError, ssl.CertificateError) as e: + if getattr(e, 'errno', -1) not in ( + ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, + ssl.SSL_ERROR_SYSCALL): + if self._state == self.S_DO_HANDSHAKE and self._on_handshake_complete: + self._on_handshake_complete(e) + raise + self._need_ssldata = e.errno == ssl.SSL_ERROR_WANT_READ + # Check for record level data that needs to be sent back. + # Happens for the initial handshake and renegotiations. + if self._outgoing.pending: + ssldata.append(self._outgoing.read()) + return (ssldata, appdata) + + def feed_appdata(self, data, offset=0): + """Feed plaintext data into the pipe. + + Return an (ssldata, offset) tuple. The ssldata element is a list of + buffers containing record level data that needs to be sent to the + remote SSL instance. The offset is the number of plaintext bytes that + were processed, which may be less than the length of data. + + NOTE: In case of short writes, this call MUST be retried with the SAME + buffer passed into the *data* argument (i.e. the ``id()`` must be the + same). This is an OpenSSL requirement. A further particularity is that + a short write will always have offset == 0, because the _ssl module + does not enable partial writes. And even though the offset is zero, + there will still be encrypted data in ssldata. + """ + if self._state == self.S_UNWRAPPED: + # pass through data in unwrapped mode + return ([data[offset:]] if offset < len(data) else [], len(data)) + ssldata = [] + view = memoryview(data) + while True: + self._need_ssldata = False + try: + if offset < len(view): + offset += self._sslobj.write(view[offset:]) + except ssl.SSLError as e: + # It is not allowed to call write() after unwrap() until the + # close_notify is acknowledged. We return the condition to the + # caller as a short write. + if e.reason == 'PROTOCOL_IS_SHUTDOWN': + e.errno = ssl.SSL_ERROR_WANT_READ + if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, + ssl.SSL_ERROR_SYSCALL): + raise + self._need_ssldata = e.errno == ssl.SSL_ERROR_WANT_READ + # See if there's any record level data back for us. + if self._outgoing.pending: + ssldata.append(self._outgoing.read()) + if offset == len(view) or self._need_ssldata: + break + return (ssldata, offset) + + +class SSLProtocol: + + def __init__(self, loop, app_protocol, sslcontext, waiter, + server_side=False, server_hostname=None): + if ssl is None: + raise RuntimeError('stdlib ssl module not available') + + if server_side: + if not sslcontext: + raise ValueError('Server side ssl needs a valid SSLContext') + else: + if not sslcontext: + # Client side may pass ssl=True to use a default + # context; in that case the sslcontext passed is None. + # The default is the same as used by urllib with + # cadefault=True. + if hasattr(ssl, '_create_stdlib_context'): + sslcontext = ssl._create_stdlib_context( + cert_reqs=ssl.CERT_REQUIRED, + check_hostname=bool(server_hostname)) + else: + # Fallback for Python 3.3. + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.set_default_verify_paths() + sslcontext.verify_mode = ssl.CERT_REQUIRED + + self._server_side = server_side + if server_hostname and not server_side and ssl.HAS_SNI: + self._server_hostname = server_hostname + else: + self._server_hostname = None + self._sslcontext = sslcontext + # SSL-specific extra info. (peercert is set later) + self._extra = dict(sslcontext=sslcontext) + + # App data write buffering + self._write_backlog = [] + self._write_buffer_size = 0 + + self._waiter = waiter + self._closing = False + self._loop = loop + self._app_protocol = app_protocol + self._app_transport = SSLProtocolTransport(self._loop, + self, self._app_protocol) + + def connection_made(self, transport): + """Called when the low-level connection is made. + """ + self._transport = transport + self._sslpipe = SSLPipe(self._sslcontext, + self._server_side, + self._server_hostname) + self._in_handshake = False + self._in_shutdown = False + self._start_handshake() + + def connection_lost(self, exc): + """Called when the low-level connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + #print("-> ssl connection_lost") + self._loop.call_soon(self._app_protocol.connection_lost, exc) + self._transport = None + self._app_transport = None + + def pause_writing(self): + """Called when the low-level transport's buffer goes over + the high-water mark. + """ + #print("-> ssl pause_writing") + self._app_protocol.pause_writing() + + def resume_writing(self): + """Called when the low-level transport's buffer drains below + the low-water mark. + """ + #print("-> ssl resume_writing") + self._app_protocol.resume_writing() + + def data_received(self, data): + """Called when some SSL data is received. + + The argument is a bytes object. + """ + #print("-> ssl data_received:", data[:20]) + try: + ssldata, appdata = self._sslpipe.feed_ssldata(data) + except ssl.SSLError as e: + logger.warning('SSL error %s (reason %s)', e.errno, e.reason) + self.abort() + return + #print("ssldata, appdata =", ssldata, appdata) + for chunk in ssldata: + self._transport.write(chunk) + #if appdata and self._paused: + #self._loop.remove_reader(self._sock_fd) + for chunk in appdata: + if chunk: + self._app_protocol.data_received(chunk) + else: + self._start_shutdown() + + def eof_received(self): + """Called when the other end of the low-level stream + is half-closed. + + If this returns a false value (including None), the transport + will close itself. If it returns a true value, closing the + transport is up to the protocol. + """ + #print("-> ssl eof_received") + try: + #self._sslpipe.feed_eof() + if self._loop.get_debug(): + logger.debug("%r received EOF", self) + if not self._in_handshake: + keep_open = self._app_protocol.eof_received() + if keep_open: + logger.warning('returning true from eof_received() ' + 'has no effect when using ssl') + finally: + self._transport.close() + + def _get_extra_info(self, name, default=None): + if name in self._extra: + return self._extra[name] + else: + return self._transport.get_extra_info(name, default) + + def _start_shutdown(self): + if self._in_shutdown: + return + self._in_shutdown = True + self._write_backlog.append([b'', 0]) + #print("_start_shutdown") + self._process_write_backlog() + + def _write_appdata(self, data): + self._write_backlog.append([data, 0]) + self._write_buffer_size += len(data) + self._process_write_backlog() + + def _start_handshake(self): + if self._loop.get_debug(): + logger.debug("%r starts SSL handshake", self) + self._handshake_start_time = self._loop.time() + else: + self._handshake_start_time = None + self._in_handshake = True + self._write_backlog.append([b'', 1]) + self._write_buffer_size += 1 + #print("_start_handshake") + self._loop.call_soon(self._process_write_backlog) + + def _on_handshake_complete(self, handshake_exc): + self._in_handshake = False + + sslobj = self._sslpipe.ssl_object + peercert = None if handshake_exc else sslobj.getpeercert() + try: + if handshake_exc is not None: + raise handshake_exc + if not hasattr(self._sslcontext, 'check_hostname'): + # Verify hostname if requested, Python 3.4+ uses check_hostname + # and checks the hostname in do_handshake() + if (self._server_hostname and + self._sslcontext.verify_mode != ssl.CERT_NONE): + ssl.match_hostname(peercert, self._server_hostname) + except BaseException as exc: + if self._loop.get_debug(): + if isinstance(exc, ssl.CertificateError): + logger.warning("%r: SSL handshake failed " + "on verifying the certificate", + self, exc_info=True) + else: + logger.warning("%r: SSL handshake failed", + self, exc_info=True) + self._transport.close() + if isinstance(exc, Exception): + if self._waiter is not None: + self._waiter.set_exception(exc) + return + else: + raise + + if self._loop.get_debug(): + dt = self._loop.time() - self._handshake_start_time + logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3) + + # Add extra info that becomes available after handshake. + self._extra.update(peercert=peercert, + cipher=sslobj.cipher(), + compression=sslobj.compression(), + ) + self._app_protocol.connection_made(self._app_transport) + if self._waiter is not None: + # wait until protocol.connection_made() has been called + self._waiter._set_result_unless_cancelled(None) + # In case transport.write() was already called + self._process_write_backlog() + + def _process_write_backlog(self): + # Try to make progress on the write backlog. + if self._transport is None: + return + try: + for i in range(len(self._write_backlog)): + data, offset = self._write_backlog[0] + if data: + ssldata, offset = self._sslpipe.feed_appdata(data, offset) + elif offset: + ssldata, offset = self._sslpipe.do_handshake(self._on_handshake_complete), 1 + else: + ssldata, offset = self._sslpipe.shutdown(self._finalize), 1 + ## Temporarily set _closing to False to prevent + ## underlying write() from raising an error. + #saved, self._closing = self._closing, False + for chunk in ssldata: + self._transport.write(chunk) + #self._closing = saved + if offset < len(data): + self._write_backlog[0][1] = offset + # A short write means that a write is blocked on a read + # We need to enable reading if it is not enabled!! + assert self._sslpipe.need_ssldata + if self._transport._paused: + self._transport.resume_reading() + break + # An entire chunk from the backlog was processed. We can + # delete it and reduce the outstanding buffer size. + del self._write_backlog[0] + self._write_buffer_size -= offset + except BaseException as exc: + if self._in_handshake: + self._on_handshake_complete(exc) + else: + self._fatal_error(exc, 'Fatal error on SSL transport') + + def _fatal_error(self, exc, message='Fatal error on transport'): + # Should be called from exception handler only. + if isinstance(exc, (BrokenPipeError, ConnectionResetError)): + if self._loop.get_debug(): + logger.debug("%r: %s", self, message, exc_info=True) + else: + self._loop.call_exception_handler({ + 'message': message, + 'exception': exc, + 'transport': self._transport, + 'protocol': self, + }) + if self._transport: + self._transport._force_close(exc) + + def _finalize(self): + if self._transport is not None: + self._transport.close() + + +class SSLProtocolTransport(transports._FlowControlMixin, + transports.Transport): + + def __init__(self, loop, ssl_protocol, app_protocol): + self._loop = loop + self._ssl_protocol = ssl_protocol + self._app_protocol = app_protocol + + def get_extra_info(self, name, default=None): + """Get optional transport information.""" + return self._ssl_protocol._get_extra_info(name, default) + + def close(self): + """Close the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + self._ssl_protocol._start_shutdown() + + def pause_reading(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume_reading() is called. + """ + + def resume_reading(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + + def set_write_buffer_limits(self, high=None, low=None): + """Set the high- and low-water limits for write flow control. + + These two values control when to call the protocol's + pause_writing() and resume_writing() methods. If specified, + the low-water limit must be less than or equal to the + high-water limit. Neither value can be negative. + + The defaults are implementation-specific. If only the + high-water limit is given, the low-water limit defaults to a + implementation-specific value less than or equal to the + high-water limit. Setting high to zero forces low to zero as + well, and causes pause_writing() to be called whenever the + buffer becomes non-empty. Setting low to zero causes + resume_writing() to be called only once the buffer is empty. + Use of zero for either limit is generally sub-optimal as it + reduces opportunities for doing I/O and computation + concurrently. + """ + + def get_write_buffer_size(self): + """Return the current size of the write buffer.""" + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + #print("write:", data) + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError("data: expecting a bytes-like instance, got {!r}" + .format(type(data).__name__)) + if not data: + return + self._ssl_protocol._write_appdata(data) + + def write_eof(self): + """Close the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + + def can_write_eof(self): + """Return True if this transport supports write_eof(), False if not.""" + return False + + def abort(self): + """Close the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + try: + self._ssl_protocol._transport.abort() + finally: + self._ssl_protocol._finalize() + diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -928,6 +928,7 @@ class EventLoopTestsMixin: proto.transport.close() client.close() server.close() + self.loop.run_until_complete(proto.done) @unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') @@ -953,6 +954,7 @@ class EventLoopTestsMixin: proto.transport.close() client.close() server.close() + self.loop.run_until_complete(proto.done) def test_create_server_sock(self): proto = asyncio.Future(loop=self.loop) diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py --- a/Lib/test/test_asyncio/test_selector_events.py +++ b/Lib/test/test_asyncio/test_selector_events.py @@ -62,9 +62,10 @@ class BaseSelectorEventLoopTests(test_ut with test_utils.disable_logger(): transport = self.loop._make_ssl_transport( m, asyncio.Protocol(), m, waiter) - self.assertIsInstance(transport, _SelectorSslTransport) + #self.assertIsInstance(transport, _SelectorSslTransport) @mock.patch('asyncio.selector_events.ssl', None) + @mock.patch('asyncio.sslproto.ssl', None) def test_make_ssl_transport_without_ssl_error(self): m = mock.Mock() self.loop.add_reader = mock.Mock()