diff --git a/asyncio/sslproto.py b/asyncio/sslproto.py index c2c4b95..ebdf3ea 100644 --- a/asyncio/sslproto.py +++ b/asyncio/sslproto.py @@ -7,6 +7,7 @@ from . import base_events from . import compat +from . import futures from . import protocols from . import transports from .log import logger @@ -411,7 +412,7 @@ class SSLProtocol(protocols.Protocol): def __init__(self, loop, app_protocol, sslcontext, waiter, server_side=False, server_hostname=None, - call_connection_made=True): + call_connection_made=True, shutdown_timeout=5.0): if ssl is None: raise RuntimeError('stdlib ssl module not available') @@ -442,6 +443,8 @@ def __init__(self, loop, app_protocol, sslcontext, waiter, self._session_established = False self._in_handshake = False self._in_shutdown = False + self._shutdown_timeout = shutdown_timeout + self._shutdown_timeout_handle = None # transport, ex: SelectorSocketTransport self._transport = None self._call_connection_made = call_connection_made @@ -551,6 +554,15 @@ def _start_shutdown(self): self._in_shutdown = True self._write_appdata(b'') + if self._shutdown_timeout is not None: + self._shutdown_timeout_handle = self._loop.call_later( + self._shutdown_timeout, self._start_shutdown_timeout) + + def _start_shutdown_timeout(self): + if self._transport is not None: + self._fatal_error( + futures.TimeoutError(), 'Can not complete shitdown operation') + def _write_appdata(self, data): self._write_backlog.append((data, 0)) self._write_buffer_size += len(data) @@ -682,6 +694,9 @@ def _fatal_error(self, exc, message='Fatal error on transport'): def _finalize(self): if self._transport is not None: self._transport.close() + if self._shutdown_timeout_handle is not None: + self._shutdown_timeout_handle.cancel() + self._shutdown_timeout_handle = None def _abort(self): if self._transport is not None: