diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index a82cc79..dbb69d5 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -6,6 +6,7 @@ __all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol', 'LimitOverrunError', ] +import functools import socket if hasattr(socket, 'AF_UNIX'): @@ -145,6 +146,15 @@ if hasattr(socket, 'AF_UNIX'): return (yield from loop.create_unix_server(factory, path, **kwds)) +def _copy_result(waiter, drain_waiter): + if drain_waiter.cancelled(): + waiter.cancel() + elif drain_waiter.exception() is not None: + waiter.set_exception(drain_waiter.exception()) + else: + waiter.set_result(drain_waiter.result()) + + class FlowControlMixin(protocols.Protocol): """Reusable flow control logic for StreamWriter.drain(). @@ -204,10 +214,19 @@ class FlowControlMixin(protocols.Protocol): raise ConnectionResetError('Connection lost') if not self._paused: return - waiter = self._drain_waiter - assert waiter is None or waiter.cancelled() + drain_waiter = self._drain_waiter + if drain_waiter is None: + assert drain_waiter is None or drain_waiter.cancelled() + drain_waiter = self._loop.create_future() + self._drain_waiter = drain_waiter + else: + # FIXME: how should we handle cancellation here? + pass + waiter = self._loop.create_future() - self._drain_waiter = waiter + cb = functools.partial(_copy_result, waiter) + drain_waiter.add_done_callback(cb) + yield from waiter