diff --git a/Lib/asynchat.py b/Lib/asynchat.py --- a/Lib/asynchat.py +++ b/Lib/asynchat.py @@ -78,6 +78,7 @@ encoding = 'latin-1' def __init__ (self, sock=None, map=None): + self._output_drained = False # for string terminator matching self.ac_in_buffer = b'' @@ -193,10 +194,20 @@ def handle_write (self): self.initiate_send() - def handle_close (self): + def handle_close(self): + if not self._output_drained: + self._output_drained = True + self._closing = True + # try to drain the output buffer (through multiple close events) + if self.producer_fifo: + if self.initiate_send() > 0 and self.producer_fifo: + self._output_drained = False + return self.close() def push (self, data): + if self._closing: + return sabs = self.ac_out_buffer_size if len(data) > sabs: for i in range(0, len(data), sabs): @@ -206,6 +217,8 @@ self.initiate_send() def push_with_producer (self, producer): + if self._closing: + return self.producer_fifo.append(producer) self.initiate_send() @@ -234,7 +247,7 @@ if first is None: ## print("first is None") self.handle_close() - return + return 0 ## print("first is not None") # handle classic producer behavior @@ -257,7 +270,7 @@ num_sent = self.send(data) except socket.error: self.handle_error() - return + return 0 if num_sent: if num_sent < len(data) or obs < len(first): @@ -265,7 +278,7 @@ else: del self.producer_fifo[0] # we tried to send some actual data - return + return num_sent def discard_buffers (self): # Emergencies only! diff --git a/Lib/asyncore.py b/Lib/asyncore.py --- a/Lib/asyncore.py +++ b/Lib/asyncore.py @@ -110,7 +110,10 @@ obj.handle_write_event() if flags & select.POLLPRI: obj.handle_expt_event() - if flags & (select.POLLHUP | select.POLLERR | select.POLLNVAL): + # close on POLLHUP after incoming data has been drained + if flags & (select.POLLHUP | select.POLLIN) == select.POLLHUP: + obj.handle_close() + if flags & (select.POLLERR | select.POLLNVAL): obj.handle_close() except socket.error as e: if e.args[0] not in _DISCONNECTED: @@ -181,9 +184,6 @@ if obj.writable() and not obj.accepting: flags |= select.POLLOUT if flags: - # Only check for exceptions if object was either readable - # or writable. - flags |= select.POLLERR | select.POLLHUP | select.POLLNVAL pollster.register(fd, flags) try: r = pollster.poll(timeout) @@ -220,7 +220,7 @@ debug = False connected = False accepting = False - closing = False + _closing = False addr = None ignore_log_types = frozenset(['warning']) @@ -530,11 +530,12 @@ def __init__(self, sock=None, map=None): dispatcher.__init__(self, sock, map) self.out_buffer = b'' + self._output_drained = False def initiate_send(self): - num_sent = 0 num_sent = dispatcher.send(self, self.out_buffer[:512]) self.out_buffer = self.out_buffer[num_sent:] + return num_sent def handle_write(self): self.initiate_send() @@ -543,11 +544,24 @@ return (not self.connected) or len(self.out_buffer) def send(self, data): + if self._closing: + return if self.debug: self.log_info('sending %s' % repr(data)) self.out_buffer = self.out_buffer + data self.initiate_send() + def handle_close(self): + if not self._output_drained: + self._output_drained = True + self._closing = True + # try to drain the output buffer (through multiple close events) + if len(self.out_buffer): + if self.initiate_send() > 0 and len(self.out_buffer): + self._output_drained = False + return + self.close() + # --------------------------------------------------------------------------- # used for debugging. # --------------------------------------------------------------------------- diff --git a/Lib/test/test_asynchat.py b/Lib/test/test_asynchat.py --- a/Lib/test/test_asynchat.py +++ b/Lib/test/test_asynchat.py @@ -15,6 +15,7 @@ HOST = support.HOST SERVER_QUIT = b'QUIT\n' +SERVER_SHUTDOWN = b'SHUTDOWN\n' if threading: class echo_server(threading.Thread): @@ -36,12 +37,18 @@ self.event.set() conn, client = self.sock.accept() self.buffer = b"" + shutdown = False # collect data until quit message is seen while SERVER_QUIT not in self.buffer: data = conn.recv(1) if not data: break self.buffer = self.buffer + data + # perform a half-duplex close + if not shutdown and SERVER_SHUTDOWN in self.buffer: + shutdown = True + self.buffer = self.buffer.replace(SERVER_SHUTDOWN, b'') + conn.shutdown(socket.SHUT_WR) # remove the SERVER_QUIT message self.buffer = self.buffer.replace(SERVER_QUIT, b'') @@ -53,7 +60,7 @@ try: # this may fail on some tests, such as test_close_when_done, since # the client closes the channel when it's done sending - while self.buffer: + while not shutdown and self.buffer: n = conn.send(self.buffer[:self.chunk_size]) time.sleep(0.001) self.buffer = self.buffer[n:] @@ -234,6 +241,22 @@ # (which could still result in the client not having received anything) self.assertGreater(len(s.buffer), 0) + def test_half_duplex_close(self): + # Check that the whole data is received by the echo_server after a + # half-duplex close (issue #12498). + + s, event = start_echo_server() + c = echo_client(b'\n', s.port) + + c.push(SERVER_SHUTDOWN) + data = b' ' * 4096 + b'\n' + c.push(data) + c.push(SERVER_QUIT) + asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01) + s.join() + + self.assertEqual(len(s.buffer), len(data)) + class TestAsynchat_WithPoll(TestAsynchat): usepoll = True diff --git a/Lib/test/test_asyncore.py b/Lib/test/test_asyncore.py --- a/Lib/test/test_asyncore.py +++ b/Lib/test/test_asyncore.py @@ -664,6 +664,52 @@ client = TestClient(self.family, server.address) self.loop_waiting_for_flag(client) + def test_half_duplex_close(self): + # Check that the output buffer is received by the client after a + # half-duplex close, and check that the client gets a close event when + # all data has been received (issue #12498). + + # The test is valid when the size of the data is large enough to ensure + # that some of it has not yet been sent by TestHandler when it gets the + # close event after the clients calls shutdown. + data = b'\0' * 4096 + data_len = len(data) + + class TestClient(BaseClient): + + def __init__(self, family, address): + BaseClient.__init__(self, family, address) + self.shutdown = False + self.buffer = b'' + self.received_all = False + + def handle_read(self): + if not self.shutdown: + self.shutdown = True + self.socket.shutdown(socket.SHUT_WR) + chunk = self.recv(1024) + self.buffer += chunk + if len(self.buffer) == data_len: + self.received_all = True + + def handle_close(self): + if self.received_all: + self.flag = True + self.close() + + class TestHandler(asyncore.dispatcher_with_send): + + def __init__(self, sock): + asyncore.dispatcher_with_send.__init__(self, sock) + self.send(data) + + def handle_read(self): + self.recv(1) + + server = BaseServer(self.family, self.addr, TestHandler) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + @unittest.skipIf(sys.platform.startswith("sunos"), "OOB support is broken on Solaris") def test_handle_expt(self):