diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -11,6 +11,7 @@ import socket from . import base_events from . import constants from . import futures +from . import sslproto from . import transports from .log import logger @@ -372,6 +373,16 @@ class BaseProactorEventLoop(base_events. return _ProactorSocketTransport(self, sock, protocol, waiter, extra, server) + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + server_side=False, server_hostname=None, + extra=None, server=None): + sslproto._check_sslproto_available() + ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter, + server_side, server_hostname) + _ProactorSocketTransport(self, rawsock, ssl_protocol, + extra=extra, server=server) + return ssl_protocol._app_transport + def _make_duplex_pipe_transport(self, sock, protocol, waiter=None, extra=None): return _ProactorDuplexPipeTransport(self, @@ -454,9 +465,8 @@ class BaseProactorEventLoop(base_events. def _write_to_self(self): self._csock.send(b'\0') - def _start_serving(self, protocol_factory, sock, ssl=None, server=None): - if ssl: - raise ValueError('IocpEventLoop is incompatible with SSL.') + def _start_serving(self, protocol_factory, sock, + sslcontext=None, server=None): def loop(f=None): try: @@ -466,9 +476,15 @@ class BaseProactorEventLoop(base_events. logger.debug("%r got a new connection from %r: %r", server, addr, conn) protocol = protocol_factory() - self._make_socket_transport( - conn, protocol, - extra={'peername': addr}, server=server) + if sslcontext: + self._make_ssl_transport( + conn, protocol, + sslcontext, None, server_side=True, + extra={'peername': addr}, server=server) + else: + self._make_socket_transport( + conn, protocol, + extra={'peername': addr}, server=server) if self.is_closed(): return f = self._proactor.accept(sock) diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -10,6 +10,7 @@ import collections import errno import functools import socket +import sys try: import ssl except ImportError: # pragma: no cover @@ -21,6 +22,7 @@ from . import events from . import futures from . import selectors from . import transports +from . import sslproto from .log import logger @@ -58,6 +60,24 @@ class BaseSelectorEventLoop(base_events. def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, server_side=False, server_hostname=None, extra=None, server=None): + try: + sslproto._check_sslproto_available() + except NotImplementedError: + return self._make_legacy_ssl_transport( + rawsock, protocol, sslcontext, waiter, + server_side=server_side, server_hostname=server_hostname, + extra=extra, server=server) + + ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter, + server_side, server_hostname) + _SelectorSocketTransport(self, rawsock, ssl_protocol, + extra=extra, server=server) + return ssl_protocol._app_transport + + def _make_legacy_ssl_transport(self, rawsock, protocol, sslcontext, + waiter, *, + server_side=False, server_hostname=None, + extra=None, server=None): return _SelectorSslTransport( self, rawsock, protocol, sslcontext, waiter, server_side, server_hostname, extra, server) @@ -588,7 +608,9 @@ class _SelectorSocketTransport(_Selector except (BlockingIOError, InterruptedError): pass except Exception as exc: - self._fatal_error(exc, 'Fatal read error on socket transport') + if (sys.platform != 'win32' or not + isinstance(exc, ConnectionAbortedError)): + self._fatal_error(exc, 'Fatal read error on socket transport') else: if data: self._protocol.data_received(data) @@ -682,26 +704,8 @@ class _SelectorSslTransport(_SelectorTra if ssl is None: raise RuntimeError('stdlib ssl module not available') - if server_side: - if not sslcontext: - raise ValueError('Server side ssl needs a valid SSLContext') - else: - if not sslcontext: - # Client side may pass ssl=True to use a default - # context; in that case the sslcontext passed is None. - # The default is secure for client connections. - if hasattr(ssl, 'create_default_context'): - # Python 3.4+: use up-to-date strong settings. - sslcontext = ssl.create_default_context() - if not server_hostname: - sslcontext.check_hostname = False - else: - # Fallback for Python 3.3. - sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - sslcontext.options |= ssl.OP_NO_SSLv2 - sslcontext.options |= ssl.OP_NO_SSLv3 - sslcontext.set_default_verify_paths() - sslcontext.verify_mode = ssl.CERT_REQUIRED + if not sslcontext: + sslcontext = sslproto._create_transport_context(server_side, server_hostname) wrap_kwargs = { 'server_side': server_side, diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py new file mode 100644 --- /dev/null +++ b/Lib/asyncio/sslproto.py @@ -0,0 +1,602 @@ + +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import protocols +from . import transports +from .log import logger + + +def _create_transport_context(server_side, server_hostname): + if server_side: + raise ValueError('Server side ssl needs a valid SSLContext') + else: + # Client side may pass ssl=True to use a default + # context; in that case the sslcontext passed is None. + # The default is secure for client connections. + if hasattr(ssl, 'create_default_context'): + # Python 3.4+: use up-to-date strong settings. + sslcontext = ssl.create_default_context() + if not server_hostname: + sslcontext.check_hostname = False + else: + # Fallback for Python 3.3. + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.options |= ssl.OP_NO_SSLv3 + sslcontext.set_default_verify_paths() + sslcontext.verify_mode = ssl.CERT_REQUIRED + return sslcontext + + +def _check_sslproto_available(): + if not hasattr(ssl, "MemoryBIO"): + raise NotImplementedError("SSL support not available on this " + "Python build") + + +class _SSLPipe(object): + """An SSL "Pipe". + + An SSL pipe allows you to communicate with an SSL/TLS protocol instance + through memory buffers. It can be used to implement a security layer for an + existing connection where you don't have access to the connection's file + descriptor, or for some reason you don't want to use it. + + An SSL pipe can be in "wrapped" and "unwrapped" mode. In unwrapped mode, + data is passed through untransformed. In wrapped mode, application level + data is encrypted to SSL record level data and vice versa. The SSL record + level is the lowest level in the SSL protocol suite and is what travels + as-is over the wire. + + An SslPipe initially is in "unwrapped" mode. To start SSL, call + :meth:`do_handshake`. To shutdown SSL again, call :meth:`unwrap`. + """ + + bufsize = 65536 + + # This previously used a socketpair to communicate with the SSL protocol + # instance but since October 2014 we're using a Memory BIO! This is + # cleaner, and more reliable on Windows. See for example issue #12 for more + # details. + + S_UNWRAPPED, S_DO_HANDSHAKE, S_WRAPPED, S_SHUTDOWN = range(4) + + def __init__(self, context, server_side, server_hostname=None): + """ + The *context* argument specifies the :class:`ssl.SSLContext` to use. + It is recommended to use :func:`~gruvi.ssl.create_ssl_context` so that + it will work on all supported Python versions. + + The *server_side* argument indicates whether this is a server side or + client side transport. + + The optional *server_hostname* argument can be used to specify the + hostname you are connecting to. You may only specify this parameter if + the _ssl module supports Server Name Indication (SNI). + """ + self._context = context + self._server_side = server_side + self._server_hostname = server_hostname + self._state = self.S_UNWRAPPED + self._incoming = ssl.MemoryBIO() + self._outgoing = ssl.MemoryBIO() + self._sslobj = None + self._need_ssldata = False + + @property + def context(self): + """The SSL context passed to the constructor.""" + return self._context + + @property + def ssl_object(self): + """The internal :class:`ssl.SSLObject` instance.""" + return self._sslobj + + @property + def need_ssldata(self): + """Whether more record level data is needed to complete a handshake + that is currently in progress.""" + return self._need_ssldata + + @property + def wrapped(self): + """Whether a security layer is currently in effect.""" + return self._state == self.S_WRAPPED + + def do_handshake(self, callback=None): + """Start the SSL handshake. Return a list of ssldata. + + The optional *callback* argument can be used to install a callback that + will be called when the handshake is complete. The callback will be + called with None if successful, else an exception instance. + """ + if self._state != self.S_UNWRAPPED: + raise RuntimeError('handshake in progress or completed') + wrapargs = () + self._sslobj = self._context.wrap_bio( + self._incoming, self._outgoing, + server_side=self._server_side, + server_hostname=self._server_hostname) + self._state = self.S_DO_HANDSHAKE + self._on_handshake_complete = callback + ssldata, appdata = self.feed_ssldata(b'', only_handshake=True) + assert len(appdata) == 0 + return ssldata + + def shutdown(self, callback=None): + """Start the SSL shutdown sequence. Return a list of ssldata. + + The optional *callback* argument can be used to install a callback that + will be called when the shutdown is complete. The callback will be + called without arguments. + """ + if self._state == self.S_UNWRAPPED: + raise RuntimeError('no security layer present') + self._state = self.S_SHUTDOWN + self._on_handshake_complete = callback + ssldata, appdata = self.feed_ssldata(b'') + assert appdata == [] or appdata == [b''] + return ssldata + + def feed_eof(self): + """Send a potentially "ragged" EOF. + + This method will raise an SSL_ERROR_EOF exception if the EOF is + unexpected. + """ + self._incoming.write_eof() + ssldata, appdata = self.feed_ssldata(b'') + assert appdata == [] or appdata == [b''] + + def feed_ssldata(self, data, only_handshake=False): + """Feed SSL record level data into the pipe. + + The data must be a bytes instance. It is OK to send an empty bytes + instance. This can be used to get ssldata for a handshake initiated by + this endpoint. + + Return a (ssldata, appdata) tuple. The ssldata element is a list of + buffers containing SSL data that needs to be sent to the remote SSL. + + The appdata element is a list of buffers containing plaintext data that + needs to be forwarded to the application. The appdata list may contain + an empty buffer indicating an SSL "close_notify" alert. This alert must + be acknowledged by calling :meth:`shutdown`. + """ + if self._state == self.S_UNWRAPPED: + # If unwrapped, pass plaintext data straight through. + return ([], [data] if data else []) + ssldata = []; appdata = [] + self._need_ssldata = False + if data: + self._incoming.write(data) + try: + if self._state == self.S_DO_HANDSHAKE: + # Call do_handshake() until it doesn't raise anymore. + self._sslobj.do_handshake() + self._state = self.S_WRAPPED + if self._on_handshake_complete: + self._on_handshake_complete(None) + if only_handshake: + return (ssldata, appdata) + if self._state == self.S_WRAPPED: + # Main state: read data from SSL until close_notify + while True: + chunk = self._sslobj.read(self.bufsize) + appdata.append(chunk) + if not chunk: # close_notify + break + if self._state == self.S_SHUTDOWN: + # Call shutdown() until it doesn't raise anymore. + self._sslobj.unwrap() + self._sslobj = None + self._state = self.S_UNWRAPPED + if self._on_handshake_complete: + self._on_handshake_complete() + if self._state == self.S_UNWRAPPED: + # Drain possible plaintext data after close_notify. + appdata.append(self._incoming.read()) + except (ssl.SSLError, ssl.CertificateError) as e: + if getattr(e, 'errno', -1) not in ( + ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, + ssl.SSL_ERROR_SYSCALL): + if self._state == self.S_DO_HANDSHAKE and self._on_handshake_complete: + self._on_handshake_complete(e) + raise + self._need_ssldata = e.errno == ssl.SSL_ERROR_WANT_READ + # Check for record level data that needs to be sent back. + # Happens for the initial handshake and renegotiations. + if self._outgoing.pending: + ssldata.append(self._outgoing.read()) + return (ssldata, appdata) + + def feed_appdata(self, data, offset=0): + """Feed plaintext data into the pipe. + + Return an (ssldata, offset) tuple. The ssldata element is a list of + buffers containing record level data that needs to be sent to the + remote SSL instance. The offset is the number of plaintext bytes that + were processed, which may be less than the length of data. + + NOTE: In case of short writes, this call MUST be retried with the SAME + buffer passed into the *data* argument (i.e. the ``id()`` must be the + same). This is an OpenSSL requirement. A further particularity is that + a short write will always have offset == 0, because the _ssl module + does not enable partial writes. And even though the offset is zero, + there will still be encrypted data in ssldata. + """ + if self._state == self.S_UNWRAPPED: + # pass through data in unwrapped mode + return ([data[offset:]] if offset < len(data) else [], len(data)) + ssldata = [] + view = memoryview(data) + while True: + self._need_ssldata = False + try: + if offset < len(view): + offset += self._sslobj.write(view[offset:]) + except ssl.SSLError as e: + # It is not allowed to call write() after unwrap() until the + # close_notify is acknowledged. We return the condition to the + # caller as a short write. + if e.reason == 'PROTOCOL_IS_SHUTDOWN': + e.errno = ssl.SSL_ERROR_WANT_READ + if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, + ssl.SSL_ERROR_SYSCALL): + raise + self._need_ssldata = e.errno == ssl.SSL_ERROR_WANT_READ + # See if there's any record level data back for us. + if self._outgoing.pending: + ssldata.append(self._outgoing.read()) + if offset == len(view) or self._need_ssldata: + break + return (ssldata, offset) + + +class SSLProtocol(protocols.Protocol): + + def __init__(self, loop, app_protocol, sslcontext, waiter, + server_side=False, server_hostname=None): + if ssl is None: + raise RuntimeError('stdlib ssl module not available') + + if not sslcontext: + sslcontext = _create_transport_context(server_side, server_hostname) + + self._server_side = server_side + if server_hostname and not server_side and ssl.HAS_SNI: + self._server_hostname = server_hostname + else: + self._server_hostname = None + self._sslcontext = sslcontext + # SSL-specific extra info. (peercert is set later) + self._extra = dict(sslcontext=sslcontext) + + # App data write buffering + self._write_backlog = [] + self._write_buffer_size = 0 + + self._waiter = waiter + self._closing = False + self._loop = loop + self._app_protocol = app_protocol + self._app_transport = _SSLProtocolTransport(self._loop, + self, self._app_protocol) + + def connection_made(self, transport): + """Called when the low-level connection is made. + """ + self._transport = transport + self._sslpipe = _SSLPipe(self._sslcontext, + self._server_side, + self._server_hostname) + self._session_established = False + self._in_handshake = False + self._in_shutdown = False + self._start_handshake() + + def connection_lost(self, exc): + """Called when the low-level connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + if self._session_established: + self._session_established = False + self._loop.call_soon(self._app_protocol.connection_lost, exc) + self._transport = None + self._app_transport = None + + def pause_writing(self): + """Called when the low-level transport's buffer goes over + the high-water mark. + """ + self._app_protocol.pause_writing() + + def resume_writing(self): + """Called when the low-level transport's buffer drains below + the low-water mark. + """ + self._app_protocol.resume_writing() + + def data_received(self, data): + """Called when some SSL data is received. + + The argument is a bytes object. + """ + try: + ssldata, appdata = self._sslpipe.feed_ssldata(data) + except ssl.SSLError as e: + if self._loop.get_debug(): + logger.warning('SSL error %s (reason %s)', e.errno, e.reason) + self._abort() + return + for chunk in ssldata: + self._transport.write(chunk) + for chunk in appdata: + if chunk: + self._app_protocol.data_received(chunk) + else: + self._start_shutdown() + + def eof_received(self): + """Called when the other end of the low-level stream + is half-closed. + + If this returns a false value (including None), the transport + will close itself. If it returns a true value, closing the + transport is up to the protocol. + """ + try: + if self._loop.get_debug(): + logger.debug("%r received EOF", self) + if not self._in_handshake: + keep_open = self._app_protocol.eof_received() + if keep_open: + logger.warning('returning true from eof_received() ' + 'has no effect when using ssl') + finally: + self._transport.close() + + def _get_extra_info(self, name, default=None): + if name in self._extra: + return self._extra[name] + else: + return self._transport.get_extra_info(name, default) + + def _start_shutdown(self): + if self._in_shutdown: + return + self._in_shutdown = True + self._write_backlog.append([b'', 0]) + self._process_write_backlog() + + def _write_appdata(self, data): + self._write_backlog.append([data, 0]) + self._write_buffer_size += len(data) + self._process_write_backlog() + + def _start_handshake(self): + if self._loop.get_debug(): + logger.debug("%r starts SSL handshake", self) + self._handshake_start_time = self._loop.time() + else: + self._handshake_start_time = None + self._in_handshake = True + self._write_backlog.append([b'', 1]) + self._write_buffer_size += 1 + self._loop.call_soon(self._process_write_backlog) + + def _on_handshake_complete(self, handshake_exc): + self._in_handshake = False + + sslobj = self._sslpipe.ssl_object + peercert = None if handshake_exc else sslobj.getpeercert() + try: + if handshake_exc is not None: + raise handshake_exc + if not hasattr(self._sslcontext, 'check_hostname'): + # Verify hostname if requested, Python 3.4+ uses check_hostname + # and checks the hostname in do_handshake() + if (self._server_hostname and + self._sslcontext.verify_mode != ssl.CERT_NONE): + ssl.match_hostname(peercert, self._server_hostname) + except BaseException as exc: + if self._loop.get_debug(): + if isinstance(exc, ssl.CertificateError): + logger.warning("%r: SSL handshake failed " + "on verifying the certificate", + self, exc_info=True) + else: + logger.warning("%r: SSL handshake failed", + self, exc_info=True) + self._transport.close() + if isinstance(exc, Exception): + if self._waiter is not None: + self._waiter.set_exception(exc) + return + else: + raise + + if self._loop.get_debug(): + dt = self._loop.time() - self._handshake_start_time + logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3) + + # Add extra info that becomes available after handshake. + self._extra.update(peercert=peercert, + cipher=sslobj.cipher(), + compression=sslobj.compression(), + ) + self._app_protocol.connection_made(self._app_transport) + if self._waiter is not None: + # wait until protocol.connection_made() has been called + self._waiter._set_result_unless_cancelled(None) + self._session_established = True + # In case transport.write() was already called + self._process_write_backlog() + + def _process_write_backlog(self): + # Try to make progress on the write backlog. + if self._transport is None: + return + try: + for i in range(len(self._write_backlog)): + data, offset = self._write_backlog[0] + if data: + ssldata, offset = self._sslpipe.feed_appdata(data, offset) + elif offset: + ssldata, offset = self._sslpipe.do_handshake(self._on_handshake_complete), 1 + else: + ssldata, offset = self._sslpipe.shutdown(self._finalize), 1 + ## Temporarily set _closing to False to prevent + ## underlying write() from raising an error. + #saved, self._closing = self._closing, False + for chunk in ssldata: + self._transport.write(chunk) + #self._closing = saved + if offset < len(data): + self._write_backlog[0][1] = offset + # A short write means that a write is blocked on a read + # We need to enable reading if it is not enabled!! + assert self._sslpipe.need_ssldata + if self._transport._paused: + self._transport.resume_reading() + break + # An entire chunk from the backlog was processed. We can + # delete it and reduce the outstanding buffer size. + del self._write_backlog[0] + self._write_buffer_size -= offset + except BaseException as exc: + if self._in_handshake: + self._on_handshake_complete(exc) + else: + self._fatal_error(exc, 'Fatal error on SSL transport') + + def _fatal_error(self, exc, message='Fatal error on transport'): + # Should be called from exception handler only. + if isinstance(exc, (BrokenPipeError, ConnectionResetError)): + if self._loop.get_debug(): + logger.debug("%r: %s", self, message, exc_info=True) + else: + self._loop.call_exception_handler({ + 'message': message, + 'exception': exc, + 'transport': self._transport, + 'protocol': self, + }) + if self._transport: + self._transport._force_close(exc) + + def _finalize(self): + if self._transport is not None: + self._transport.close() + + def _abort(self): + if self._transport is not None: + try: + self._transport.abort() + finally: + self._finalize() + + +class _SSLProtocolTransport(transports._FlowControlMixin, + transports.Transport): + + def __init__(self, loop, ssl_protocol, app_protocol): + self._loop = loop + self._ssl_protocol = ssl_protocol + self._app_protocol = app_protocol + + def get_extra_info(self, name, default=None): + """Get optional transport information.""" + return self._ssl_protocol._get_extra_info(name, default) + + def close(self): + """Close the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + self._ssl_protocol._start_shutdown() + + def pause_reading(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume_reading() is called. + """ + self._ssl_protocol._transport.pause_reading() + + def resume_reading(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + self._ssl_protocol._transport.resume_reading() + + def set_write_buffer_limits(self, high=None, low=None): + """Set the high- and low-water limits for write flow control. + + These two values control when to call the protocol's + pause_writing() and resume_writing() methods. If specified, + the low-water limit must be less than or equal to the + high-water limit. Neither value can be negative. + + The defaults are implementation-specific. If only the + high-water limit is given, the low-water limit defaults to a + implementation-specific value less than or equal to the + high-water limit. Setting high to zero forces low to zero as + well, and causes pause_writing() to be called whenever the + buffer becomes non-empty. Setting low to zero causes + resume_writing() to be called only once the buffer is empty. + Use of zero for either limit is generally sub-optimal as it + reduces opportunities for doing I/O and computation + concurrently. + """ + self._ssl_protocol._transport.set_write_buffer_limits(high, low) + + def get_write_buffer_size(self): + """Return the current size of the write buffer.""" + return self._ssl_protocol._transport.get_write_buffer_size() + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError("data: expecting a bytes-like instance, got {!r}" + .format(type(data).__name__)) + if not data: + return + self._ssl_protocol._write_appdata(data) + + def write_eof(self): + """Close the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + + def can_write_eof(self): + """Return True if this transport supports write_eof(), False if not.""" + return False + + def abort(self): + """Close the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + self._ssl_protocol._abort() diff --git a/Lib/asyncio/test_utils.py b/Lib/asyncio/test_utils.py --- a/Lib/asyncio/test_utils.py +++ b/Lib/asyncio/test_utils.py @@ -434,3 +434,8 @@ def mock_nonblocking_socket(): sock = mock.Mock(socket.socket) sock.gettimeout.return_value = 0.0 return sock + + +def force_legacy_ssl_support(): + return mock.patch('asyncio.sslproto._check_sslproto_available', + side_effect=NotImplementedError) diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -644,6 +644,10 @@ class EventLoopTestsMixin: *httpd.address) self._test_create_ssl_connection(httpd, create_connection) + def test_legacy_create_ssl_connection(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_ssl_connection() + @unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_create_ssl_unix_connection(self): @@ -660,6 +664,10 @@ class EventLoopTestsMixin: self._test_create_ssl_connection(httpd, create_connection, check_sockname) + def test_legacy_create_ssl_unix_connection(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_ssl_unix_connection() + def test_create_connection_local_addr(self): with test_utils.run_test_server() as httpd: port = support.find_unused_port() @@ -820,6 +828,10 @@ class EventLoopTestsMixin: # stop serving server.close() + def test_legacy_create_server_ssl(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_server_ssl() + @unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_create_unix_server_ssl(self): @@ -851,6 +863,10 @@ class EventLoopTestsMixin: # stop serving server.close() + def test_legacy_create_unix_server_ssl(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_unix_server_ssl() + @unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') def test_create_server_ssl_verify_failed(self): @@ -876,6 +892,10 @@ class EventLoopTestsMixin: self.assertIsNone(proto.transport) server.close() + def test_legacy_create_server_ssl_verify_failed(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_server_ssl_verify_failed() + @unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') @@ -903,6 +923,10 @@ class EventLoopTestsMixin: self.assertIsNone(proto.transport) server.close() + def test_legacy_create_unix_server_ssl_verify_failed(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_unix_server_ssl_verify_failed() + @unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') def test_create_server_ssl_match_failed(self): @@ -931,6 +955,10 @@ class EventLoopTestsMixin: proto.transport.close() server.close() + def test_legacy_create_server_ssl_match_failed(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_server_ssl_match_failed() + @unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') @@ -956,6 +984,11 @@ class EventLoopTestsMixin: proto.transport.close() client.close() server.close() + self.loop.run_until_complete(proto.done) + + def test_legacy_create_unix_server_ssl_verified(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_unix_server_ssl_verified() @unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') @@ -981,6 +1014,11 @@ class EventLoopTestsMixin: proto.transport.close() client.close() server.close() + self.loop.run_until_complete(proto.done) + + def test_legacy_create_server_ssl_verified(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_server_ssl_verified() def test_create_server_sock(self): proto = asyncio.Future(loop=self.loop) @@ -1713,20 +1751,20 @@ if sys.platform == 'win32': def create_event_loop(self): return asyncio.ProactorEventLoop() - def test_create_ssl_connection(self): - raise unittest.SkipTest("IocpEventLoop incompatible with SSL") + def test_legacy_create_ssl_connection(self): + raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL") - def test_create_server_ssl(self): - raise unittest.SkipTest("IocpEventLoop incompatible with SSL") + def test_legacy_create_server_ssl(self): + raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL") - def test_create_server_ssl_verify_failed(self): - raise unittest.SkipTest("IocpEventLoop incompatible with SSL") + def test_legacy_create_server_ssl_verify_failed(self): + raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL") - def test_create_server_ssl_match_failed(self): - raise unittest.SkipTest("IocpEventLoop incompatible with SSL") + def test_legacy_create_server_ssl_match_failed(self): + raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL") - def test_create_server_ssl_verified(self): - raise unittest.SkipTest("IocpEventLoop incompatible with SSL") + def test_legacy_create_server_ssl_verified(self): + raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL") def test_reader_callback(self): raise unittest.SkipTest("IocpEventLoop does not have add_reader()") diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py --- a/Lib/test/test_asyncio/test_selector_events.py +++ b/Lib/test/test_asyncio/test_selector_events.py @@ -62,9 +62,13 @@ class BaseSelectorEventLoopTests(test_ut with test_utils.disable_logger(): transport = self.loop._make_ssl_transport( m, asyncio.Protocol(), m, waiter) - self.assertIsInstance(transport, _SelectorSslTransport) + # Sanity check + class_name = transport.__class__.__name__ + self.assertIn("ssl", class_name.lower()) + self.assertIn("transport", class_name.lower()) @mock.patch('asyncio.selector_events.ssl', None) + @mock.patch('asyncio.sslproto.ssl', None) def test_make_ssl_transport_without_ssl_error(self): m = mock.Mock() self.loop.add_reader = mock.Mock()