import asyncio import socket from test import support from test.test_asyncio import utils as test_utils def tearDownModule(): asyncio.set_event_loop_policy(None) class _SendfileFallbackProtocol(asyncio.Protocol): def __init__(self, transp): self._transport = transp self._proto = transp.get_protocol() self._should_resume_reading = transp.is_reading() self._should_resume_writing = transp._protocol_paused transp.pause_reading() transp.set_protocol(self) if self._should_resume_writing: self._write_ready_fut = self._transport._loop.create_future() else: self._write_ready_fut = None async def drain(self): if self._transport.is_closing(): raise ConnectionError("Connection closed by peer") fut = self._write_ready_fut if fut is None: return await fut def connection_made(self, transport): raise RuntimeError("Invalid state") def connection_lost(self, exc): if self._write_ready_fut is not None: if exc is None: self._write_ready_fut.set_exception( ConnectionError("Connection is closed by peer")) else: self._write_ready_fut.set_exception(exc) self._proto.connection_lost(exc) def pause_writing(self): if self._write_ready_fut is not None: return self._write_ready_fut = self._transport._loop.create_future() def resume_writing(self): if self._write_ready_fut is None: return self._write_ready_fut.set_result(False) self._write_ready_fut = None def data_received(self, data): raise RuntimeError("Invalid state") def eof_received(self): raise RuntimeError("Invalid state") async def restore(self): self._transport.set_protocol(self._proto) if self._should_resume_reading: self._transport.resume_reading() if self._write_ready_fut is not None: self._write_ready_fut.cancel() if self._should_resume_writing: self._proto.resume_writing() class MySendfileProto(asyncio.Protocol): def __init__(self, loop=None, close_after=0): self.transport = None self.state = 'INITIAL' self.nbytes = 0 if loop is not None: self.connected = loop.create_future() self.done = loop.create_future() self.data = bytearray() self.close_after = close_after def connection_made(self, transport): self.transport = transport assert self.state == 'INITIAL', self.state self.state = 'CONNECTED' if self.connected: self.connected.set_result(None) def eof_received(self): assert self.state == 'CONNECTED', self.state self.state = 'EOF' def connection_lost(self, exc): assert self.state in ('CONNECTED', 'EOF'), self.state self.state = 'CLOSED' if self.done: self.done.set_result(None) def data_received(self, data): assert self.state == 'CONNECTED', self.state self.nbytes += len(data) self.data.extend(data) super().data_received(data) if self.close_after and self.nbytes >= self.close_after: self.transport.close() async def sendfile(transp): proto = _SendfileFallbackProtocol(transp) try: data = b'x' * (1024 * 24) while True: await proto.drain() transp.write(data) finally: await proto.restore() class ProactorEventLoopTests(test_utils.TestCase): def run_loop(self, coro): return self.loop.run_until_complete(coro) def test_sendfile_bug(self): self.loop = asyncio.ProactorEventLoop() self.set_event_loop(self.loop) port = support.find_unused_port() srv_proto = MySendfileProto(loop=self.loop, close_after=1024) srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) srv_sock.bind((support.HOST, port)) server = self.run_loop(self.loop.create_server( lambda: srv_proto, sock=srv_sock)) srv_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 4 * 1024) cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) cli_sock.connect((support.HOST, port)) cli_proto = MySendfileProto(loop=self.loop) tr, pr = self.run_loop(self.loop.create_connection( lambda: cli_proto, sock=cli_sock)) cli_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 4 * 1024) tr.set_write_buffer_limits(high=4 * 1024) with self.assertRaises(ConnectionError): self.run_loop( sendfile(cli_proto.transport)) self.run_loop(srv_proto.done) srv_proto.transport.close() cli_proto.transport.close() self.run_loop(srv_proto.done) self.run_loop(cli_proto.done) server.close() self.run_loop(server.wait_closed()) test_utils.run_briefly(self.loop) self.loop.close() support.gc_collect()