Index: Lib/SocketServer.py =================================================================== --- Lib/SocketServer.py (revision 61289) +++ Lib/SocketServer.py (working copy) @@ -196,9 +196,19 @@ """Constructor. May be extended, do not override.""" self.server_address = server_address self.RequestHandlerClass = RequestHandlerClass - self.__is_shut_down = threading.Event() - self.__serving = False + # Protects __shutdown_requests and __is_serving, and signals + # when shutdown has finished. + self.__condition = threading.Condition(threading.Lock()) + # The number of shutdown calls that haven't killed a + # serve_forever loop yet. Generally it'll be -1 (when a loop + # has terminated before the matching .shutdown call), 0 or 1 + # unless you're doing tricky stuff. + self.__shutdown_requests = 0 + # True while serve_forever is looping. When this is False, + # shutdown can return immediately. + self.__is_serving = False + def server_activate(self): """Called by constructor to activate the server. @@ -212,29 +222,65 @@ Polls for shutdown every poll_interval seconds. Ignores self.timeout. If you need to do periodic tasks, do them in - another thread. + another thread. If serve_forever throws an exception, the next + shutdown will apply to this loop: they apply in matched pairs. """ - self.__serving = True - self.__is_shut_down.clear() - while self.__serving: - # XXX: Consider using another file descriptor or - # connecting to the socket to wake this up instead of - # polling. Polling reduces our responsiveness to a - # shutdown request and wastes cpu at all other times. - r, w, e = select.select([self], [], [], poll_interval) - if r: - self._handle_request_noblock() - self.__is_shut_down.set() + with self.__condition: + assert not self.__is_serving, ( + 'Only one serve_forever loop can be running at a time.') + if self.__shutdown_requests > 0: + # A shutdown was called that didn't kill a loop. Let + # it kill this one. + self.__shutdown_requests -= 1 + return + # Serve until the next shutdown request. + self.__is_serving = True + try: + while True: + with self.__condition: + if self.__shutdown_requests > 0: break + # XXX: Consider using another file descriptor or + # connecting to the socket to wake this up instead of + # polling. Polling reduces our responsiveness to a + # shutdown request and wastes cpu at all other times. + r, w, e = select.select([self], [], [], poll_interval) + if r: + self._handle_request_noblock() + finally: + with self.__condition: + self.__is_serving = False + # Even if we're exiting via an exception, there's + # likely to be a shutdown call that wants to kill this + # particular loop. Pretend it succeeded. + self.__shutdown_requests -= 1 + # All the shutdown requests can continue now. Even if + # their matching serve_forever loop hasn't happened + # yet, it's guaranteed never to accept a connection by + # the guard on __shutdown_requests. + self.__condition.notifyAll() def shutdown(self): """Stops the serve_forever loop. - Blocks until the loop has finished. This must be called while - serve_forever() is running in another thread, or it will - deadlock. + Blocks until the loop is guaranteed not to accept any more + connections. + + If this is called before serve_forever starts looping, it will + return immediately and kill the next one to start. If it's + called after the loop terminates via an exception, it will + just return immediately, unless another loop has started up in + the meantime, in which case it will wait until that one + finishes. This behavior is designed to support repeated + serve_forever..shutdown cycles. """ - self.__serving = False - self.__is_shut_down.wait() + with self.__condition: + # All increments here are matched by a decrement in serve_forever. + self.__shutdown_requests += 1 + # If the loop hasn't started yet, this will return + # immediately, rather than uselessly blocking until it + # tries to start. + while self.__is_serving: + self.__condition.wait() # The distinction between handling, getting, processing and # finishing a request is fairly arbitrary. Remember: Index: Lib/test/test_socketserver.py =================================================================== --- Lib/test/test_socketserver.py (revision 61289) +++ Lib/test/test_socketserver.py (working copy) @@ -49,6 +49,14 @@ pass +def EchoHandler(hdlrbase): + class MyHandler(hdlrbase): + def handle(self): + line = self.rfile.readline() + self.wfile.write(line) + return MyHandler + + @contextlib.contextmanager def simple_subprocess(testcase): pid = os.fork() @@ -112,13 +120,8 @@ self.server_close() raise - class MyHandler(hdlrbase): - def handle(self): - line = self.rfile.readline() - self.wfile.write(line) - if verbose: print "creating server" - server = MyServer(addr, MyHandler) + server = MyServer(addr, EchoHandler(hdlrbase)) self.assertEquals(server.server_address, server.socket.getsockname()) return server @@ -243,7 +246,87 @@ # SocketServer.DatagramRequestHandler, # self.dgram_examine) + def assertServesOne(self, server): + """Checks that the server can accept a request. + Assumes that the server's handler is an EchoHandler. + """ + t = threading.Thread(target=server.serve_forever, + kwargs={'poll_interval': 0.01}) + t.start() + try: + self.stream_examine(server.address_family, server.server_address) + finally: + server.shutdown() + t.join() + + def test_shutdown_makes_next_serve_forever_return(self): + server = SocketServer.TCPServer( + ('', 0), + EchoHandler(SocketServer.StreamRequestHandler)) + server.shutdown() + server.serve_forever() # Should return immediately. + self.assertServesOne(server) + + def test_shutdown_doesnt_deadlock_if_serve_forever_raised(self): + server = SocketServer.TCPServer( + ('', 0), + EchoHandler(SocketServer.StreamRequestHandler)) + old_fileno = server.fileno + server.fileno = lambda *args: "not a fd" + self.assertRaises(TypeError, server.serve_forever) + server.shutdown() # Should return immediately. + server.fileno = old_fileno + self.assertServesOne(server) + + def test_can_serve_forever_repeatedly(self): + # In case people are calling serve_forever in a loop, catching + # the exceptions, it should keep working. + server = SocketServer.TCPServer( + ('', 0), + EchoHandler(SocketServer.StreamRequestHandler)) + old_fileno = server.fileno + server.fileno = lambda *args: "not a fd" + self.assertRaises(TypeError, server.serve_forever) + self.assertRaises(TypeError, server.serve_forever) + server.fileno = old_fileno + + # Another serve_forever call successfully serves, and other + # shutdown calls are blocked while it's serving. + t = threading.Thread(target=server.serve_forever, + kwargs={'poll_interval': 0.01}) + t.setDaemon(True) + t.start() + try: + self.stream_examine(server.address_family, server.server_address) + finally: + try: + s1 = threading.Thread(target=server.shutdown) + s1.start() + def check_shutdown(start, server, done): + start.set() + server.shutdown() + done.set() + start = threading.Event() + done = threading.Event() + s2 = threading.Thread(target=check_shutdown, + args=(start, server, done)) + s2.start() + start.wait() + self.assertFalse(done.isSet()) + server.shutdown() + s2.join() + s1.join() + t.join() + except: + # On error, call shutdown a bunch to make sure the + # serving loop exits. + l = [threading.Thread(target=server.shutdown) + for i in xrange(3)] + for t in l: t.start() + raise + + def test_main(): if imp.lock_held(): # If the import lock is held, the threads will hang