Index: Lib/SocketServer.py =================================================================== --- Lib/SocketServer.py (revision 61169) +++ Lib/SocketServer.py (working copy) @@ -130,8 +130,13 @@ import socket +import select import sys import os +try: + import threading +except ImportError: + import dummy_threading as threading __all__ = ["TCPServer","UDPServer","ForkingUDPServer","ForkingTCPServer", "ThreadingUDPServer","ThreadingTCPServer","BaseRequestHandler", @@ -190,6 +195,8 @@ """Constructor. May be extended, do not override.""" self.server_address = server_address self.RequestHandlerClass = RequestHandlerClass + self.__is_shut_down = threading.Event() + self.__serving = False def server_activate(self): """Called by constructor to activate the server. @@ -199,27 +206,66 @@ """ pass - def serve_forever(self): - """Handle one request at a time until doomsday.""" - while 1: - self.handle_request() + def serve_forever(self, poll_interval=0.5): + """Handle one request at a time until shutdown. + Polls for shutdown every poll_interval seconds. Ignores + self.timeout. If you need to do periodic tasks, do them in + another thread. + """ + self.__serving = True + self.__is_shut_down.clear() + while self.__serving: + # XXX: Consider using another file descriptor or + # connecting to the socket to wake this up instead of + # polling. Polling reduces our responsiveness to a + # shutdown request and wastes cpu at all other times. + r, w, e = select.select([self], [], [], poll_interval) + if r: + self._handle_request_noblock() + self.__is_shut_down.set() + + def shutdown(self): + """Stops the serve_forever loop. + + Blocks until the loop has finished. This must be called while + serve_forever() is running in another thread, or it will + deadlock. + """ + self.__serving = False + self.__is_shut_down.wait() + # The distinction between handling, getting, processing and # finishing a request is fairly arbitrary. Remember: # # - handle_request() is the top-level call. It calls - # await_request(), verify_request() and process_request() - # - get_request(), called by await_request(), is different for - # stream or datagram sockets + # select, get_request(), verify_request() and process_request() + # - get_request() is different for stream or datagram sockets # - process_request() is the place that may fork a new process # or create a new thread to finish the request # - finish_request() instantiates the request handler class; # this constructor will handle the request all by itself def handle_request(self): - """Handle one request, possibly blocking.""" + """Handle one request, possibly blocking. + + Respects self.timeout. + """ + fd_sets = select.select([self], [], [], self.timeout) + if not fd_sets[0]: + self.handle_timeout() + return + self._handle_request_noblock() + + def _handle_request_noblock(self): + """Handle one request, without blocking. + + I assume that select.select has returned that the socket is + readable before this function was called, so there should be + no risk of blocking in get_request(). + """ try: - request, client_address = self.await_request() + request, client_address = self.get_request() except socket.error: return if self.verify_request(request, client_address): @@ -229,21 +275,6 @@ self.handle_error(request, client_address) self.close_request(request) - def await_request(self): - """Call get_request or handle_timeout, observing self.timeout. - - Returns value from get_request() or raises socket.timeout exception if - timeout was exceeded. - """ - if self.timeout is not None: - # If timeout == 0, you're responsible for your own fd magic. - import select - fd_sets = select.select([self], [], [], self.timeout) - if not fd_sets[0]: - self.handle_timeout() - raise socket.timeout("Listening timed out") - return self.get_request() - def handle_timeout(self): """Called if no new request arrives within self.timeout. @@ -523,7 +554,6 @@ def process_request(self, request, client_address): """Start a new thread to process the request.""" - import threading t = threading.Thread(target = self.process_request_thread, args = (request, client_address)) if self.daemon_threads: Index: Lib/test/test_socketserver.py =================================================================== --- Lib/test/test_socketserver.py (revision 61169) +++ Lib/test/test_socketserver.py (working copy) @@ -21,7 +21,6 @@ test.test_support.requires("network") -NREQ = 3 TEST_STR = "hello world\n" HOST = "localhost" @@ -46,40 +45,32 @@ pass -class MyMixinServer: - def serve_a_few(self): - for i in range(NREQ): - self.handle_request() - - def handle_error(self, request, client_address): - self.close_request(request) - self.server_close() - raise - - class ServerThread(threading.Thread): def __init__(self, addr, svrcls, hdlrcls): threading.Thread.__init__(self) - self.__addr = addr - self.__svrcls = svrcls - self.__hdlrcls = hdlrcls - self.ready = threading.Event() - def run(self): - class svrcls(MyMixinServer, self.__svrcls): - pass - if verbose: print "thread: creating server" - svr = svrcls(self.__addr, self.__hdlrcls) + class Server(svrcls): + def handle_error(self, request, client_address): + self.close_request(request) + self.server_close() + raise + + if verbose: print "creating server" + self.server = Server(addr, hdlrcls) # We had the OS pick a port, so pull the real address out of # the server. - self.addr = svr.server_address - self.port = self.addr[1] - if self.addr != svr.socket.getsockname(): + self.addr = self.server.server_address + if self.addr != self.server.socket.getsockname(): raise RuntimeError('server_address was %s, expected %s' % (self.addr, svr.socket.getsockname())) - self.ready.set() - if verbose: print "thread: serving three times" - svr.serve_a_few() + + def run(self): + if verbose: print "thread: serving until shutdown" + # Main thread will call svr.shutdown(). Small polling + # interval because wasting CPU in a test is better than + # wasting absolute time. + self.server.serve_forever(poll_interval=0.01) + if verbose: print "thread: done" @@ -151,16 +142,15 @@ print "CLASS =", svrcls t = ServerThread(addr, svrcls, MyHandler) if verbose: print "server created" + t.setDaemon(True) # In case this function raises. t.start() if verbose: print "server running" - t.ready.wait(10) - self.assert_(t.ready.isSet(), - "%s not ready within a reasonable time" % svrcls) addr = t.addr - for i in range(NREQ): + for i in range(3): if verbose: print "test client", i testfunc(svrcls.address_family, addr) if verbose: print "waiting for server" + t.server.shutdown() t.join() if verbose: print "done"