diff --git a/Lib/asynchat.py b/Lib/asynchat.py --- a/Lib/asynchat.py +++ b/Lib/asynchat.py @@ -194,6 +194,11 @@ self.initiate_send() def handle_close (self): + if not self._closing: + self._closing = True + # try to drain the output buffer + while self.writable() and self.initiate_send() > 0: + pass self.close() def push (self, data): @@ -219,7 +224,8 @@ def writable (self): "predicate for inclusion in the writable for select()" - return self.producer_fifo or (not self.connected) + return (not self.connected + or (self.producer_fifo and not self._hangup)) def close_when_done (self): "automatically close this channel once the outgoing queue is empty" @@ -234,7 +240,7 @@ if first is None: ## print("first is None") self.handle_close() - return + return 0 ## print("first is not None") # handle classic producer behavior @@ -257,7 +263,7 @@ num_sent = self.send(data) except socket.error: self.handle_error() - return + return 0 if num_sent: if num_sent < len(data) or obs < len(first): @@ -265,7 +271,7 @@ else: del self.producer_fifo[0] # we tried to send some actual data - return + return num_sent def discard_buffers (self): # Emergencies only! diff --git a/Lib/asyncore.py b/Lib/asyncore.py --- a/Lib/asyncore.py +++ b/Lib/asyncore.py @@ -110,7 +110,9 @@ obj.handle_write_event() if flags & select.POLLPRI: obj.handle_expt_event() - if flags & (select.POLLHUP | select.POLLERR | select.POLLNVAL): + if flags & select.POLLHUP: + obj.handle_hangup_evt() + if flags & (select.POLLERR | select.POLLNVAL): obj.handle_close() except socket.error as e: if e.args[0] not in _DISCONNECTED: @@ -220,7 +222,8 @@ debug = False connected = False accepting = False - closing = False + _closing = False + _hangup = False addr = None ignore_log_types = frozenset(['warning']) @@ -319,7 +322,9 @@ return True def writable(self): - return True + # POLLHUP and POLLOUT are mutually-exclusive; a stream can never be + # writable if a hangup has occurred. + return not self._hangup # ================================================== # socket object methods. @@ -475,6 +480,9 @@ else: self.handle_expt() + def handle_hangup_evt(self): + self._hangup = True + def handle_error(self): nil, t, v, tbinfo = compact_traceback() @@ -532,15 +540,16 @@ self.out_buffer = b'' def initiate_send(self): - num_sent = 0 num_sent = dispatcher.send(self, self.out_buffer[:512]) self.out_buffer = self.out_buffer[num_sent:] + return num_sent def handle_write(self): self.initiate_send() def writable(self): - return (not self.connected) or len(self.out_buffer) + return (not self.connected + or (len(self.out_buffer) and not self._hangup)) def send(self, data): if self.debug: @@ -548,6 +557,14 @@ self.out_buffer = self.out_buffer + data self.initiate_send() + def handle_close(self): + if not self._closing: + self._closing = True + # try to drain the output buffer + while self.writable() and self.initiate_send() > 0: + pass + self.close() + # --------------------------------------------------------------------------- # used for debugging. # --------------------------------------------------------------------------- diff --git a/Lib/test/test_asynchat.py b/Lib/test/test_asynchat.py --- a/Lib/test/test_asynchat.py +++ b/Lib/test/test_asynchat.py @@ -15,6 +15,7 @@ HOST = support.HOST SERVER_QUIT = b'QUIT\n' +SERVER_SHUTDOWN = b'SHUTDOWN\n' if threading: class echo_server(threading.Thread): @@ -36,12 +37,18 @@ self.event.set() conn, client = self.sock.accept() self.buffer = b"" + shutdown = False # collect data until quit message is seen while SERVER_QUIT not in self.buffer: data = conn.recv(1) if not data: break self.buffer = self.buffer + data + # perform a half-duplex close + if not shutdown and SERVER_SHUTDOWN in self.buffer: + shutdown = True + self.buffer = self.buffer.replace(SERVER_SHUTDOWN, b'') + conn.shutdown(socket.SHUT_WR) # remove the SERVER_QUIT message self.buffer = self.buffer.replace(SERVER_QUIT, b'') @@ -53,7 +60,7 @@ try: # this may fail on some tests, such as test_close_when_done, since # the client closes the channel when it's done sending - while self.buffer: + while not shutdown and self.buffer: n = conn.send(self.buffer[:self.chunk_size]) time.sleep(0.001) self.buffer = self.buffer[n:] @@ -234,6 +241,22 @@ # (which could still result in the client not having received anything) self.assertGreater(len(s.buffer), 0) + def test_half_duplex_close(self): + # Check that the whole data is received by the echo_server after a + # half-duplex close (issue #12498). + + s, event = start_echo_server() + c = echo_client(b'\n', s.port) + + c.push(SERVER_SHUTDOWN) + data = b' ' * 4096 + b'\n' + c.push(data) + c.push(SERVER_QUIT) + asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01) + s.join() + + self.assertEqual(len(s.buffer), len(data)) + class TestAsynchat_WithPoll(TestAsynchat): usepoll = True diff --git a/Lib/test/test_asyncore.py b/Lib/test/test_asyncore.py --- a/Lib/test/test_asyncore.py +++ b/Lib/test/test_asyncore.py @@ -49,6 +49,7 @@ handle_write_event = handle_read_event handle_close = handle_read_event handle_expt_event = handle_read_event + handle_hangup_evt = handle_read_event class crashingdummy: def __init__(self): @@ -131,14 +132,15 @@ def test_readwrite(self): # Check that correct methods are called by readwrite() - attributes = ('read', 'expt', 'write', 'closed', 'error_handled') + attributes = ('read', 'expt', 'write', 'closed', 'error_handled', + 'hangup') expected = ( (select.POLLIN, 'read'), (select.POLLPRI, 'expt'), (select.POLLOUT, 'write'), (select.POLLERR, 'closed'), - (select.POLLHUP, 'closed'), + (select.POLLHUP, 'hangup'), (select.POLLNVAL, 'closed'), ) @@ -148,6 +150,7 @@ self.write = False self.closed = False self.expt = False + self.hangup = False self.error_handled = False def handle_read_event(self): @@ -162,6 +165,9 @@ def handle_expt_event(self): self.expt = True + def handle_hangup_evt(self): + self.hangup = True + def handle_error(self): self.error_handled = True @@ -664,6 +670,52 @@ client = TestClient(self.family, server.address) self.loop_waiting_for_flag(client) + def test_half_duplex_close(self): + # Check that the output buffer is received by the client after a + # half-duplex close, and check that the client gets a close event when + # all data has been received (issue #12498). + + # The test is valid when the size of the data is large enough to ensure + # that some of it has not yet been sent by TestHandler when it gets the + # close event after the clients calls shutdown. + data = b'\0' * 4096 + data_len = len(data) + + class TestClient(BaseClient): + + def __init__(self, family, address): + BaseClient.__init__(self, family, address) + self.shutdown = False + self.buffer = b'' + self.received_all = False + + def handle_read(self): + if not self.shutdown: + self.shutdown = True + self.socket.shutdown(socket.SHUT_WR) + chunk = self.recv(1024) + self.buffer += chunk + if len(self.buffer) == data_len: + self.received_all = True + + def handle_close(self): + if self.received_all: + self.flag = True + self.close() + + class TestHandler(asyncore.dispatcher_with_send): + + def __init__(self, sock): + asyncore.dispatcher_with_send.__init__(self, sock) + self.send(data) + + def handle_read(self): + self.recv(1) + + server = BaseServer(self.family, self.addr, TestHandler) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + @unittest.skipIf(sys.platform.startswith("sunos"), "OOB support is broken on Solaris") def test_handle_expt(self):