diff --git a/Lib/socketserver.py b/Lib/socketserver.py index 41a37667721..c4d544b372d 100644 --- a/Lib/socketserver.py +++ b/Lib/socketserver.py @@ -547,8 +547,10 @@ if hasattr(os, "fork"): timeout = 300 active_children = None max_children = 40 + # If true, server_close() waits until all child processes complete. + _block_on_close = False - def collect_children(self): + def collect_children(self, *, blocking=False): """Internal routine to wait for children that have exited.""" if self.active_children is None: return @@ -572,7 +574,8 @@ if hasattr(os, "fork"): # Now reap all defunct children. for pid in self.active_children.copy(): try: - pid, _ = os.waitpid(pid, os.WNOHANG) + flags = 0 if blocking else os.WNOHANG + pid, _ = os.waitpid(pid, flags) # if the child hasn't exited yet, pid will be 0 and ignored by # discard() below self.active_children.discard(pid) @@ -621,6 +624,10 @@ if hasattr(os, "fork"): finally: os._exit(status) + def server_close(self): + super().server_close() + self.collect_children(blocking=self._block_on_close) + class ThreadingMixIn: """Mix-in class to handle each request in a new thread.""" @@ -628,6 +635,11 @@ class ThreadingMixIn: # Decides how threads will act upon termination of the # main process daemon_threads = False + # If true, server_close() waits until all non-daemonic threads terminate. + _block_on_close = False + # For non-daemonic threads, list of threading.Threading objects + # used by server_close() to wait for all threads completion. + _threads = None def process_request_thread(self, request, client_address): """Same as in BaseServer but as a thread. @@ -647,8 +659,21 @@ class ThreadingMixIn: t = threading.Thread(target = self.process_request_thread, args = (request, client_address)) t.daemon = self.daemon_threads + if not t.daemon and self._block_on_close: + if self._threads is None: + self._threads = [] + self._threads.append(t) t.start() + def server_close(self): + super().server_close() + if self._block_on_close: + threads = self._threads + self._threads = None + if threads: + for thread in threads: + thread.join() + if hasattr(os, "fork"): class ForkingUDPServer(ForkingMixIn, UDPServer): pass diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py index fc067138c3b..d341ef8779b 100644 --- a/Lib/test/test_logging.py +++ b/Lib/test/test_logging.py @@ -883,6 +883,7 @@ if threading: """ allow_reuse_address = True + _block_on_close = True def __init__(self, addr, handler, poll_interval=0.5, bind_and_activate=True): @@ -915,6 +916,8 @@ if threading: before calling :meth:`start`, so that the server will set up the socket and listen on it. """ + _block_on_close = True + def __init__(self, addr, handler, poll_interval=0.5, bind_and_activate=True): class DelegatingUDPRequestHandler(DatagramRequestHandler): @@ -1474,11 +1477,11 @@ class SocketHandlerTest(BaseTest): def tearDown(self): """Shutdown the TCP server.""" try: - if self.server: - self.server.stop(2.0) if self.sock_hdlr: self.root_logger.removeHandler(self.sock_hdlr) self.sock_hdlr.close() + if self.server: + self.server.stop(2.0) finally: BaseTest.tearDown(self) diff --git a/Lib/test/test_socketserver.py b/Lib/test/test_socketserver.py index 43621337e03..8177c417877 100644 --- a/Lib/test/test_socketserver.py +++ b/Lib/test/test_socketserver.py @@ -48,11 +48,11 @@ def receive(sock, n, timeout=20): if HAVE_UNIX_SOCKETS and HAVE_FORKING: class ForkingUnixStreamServer(socketserver.ForkingMixIn, socketserver.UnixStreamServer): - pass + _block_on_close = True class ForkingUnixDatagramServer(socketserver.ForkingMixIn, socketserver.UnixDatagramServer): - pass + _block_on_close = True @contextlib.contextmanager @@ -62,24 +62,14 @@ def simple_subprocess(testcase): if pid == 0: # Don't raise an exception; it would be caught by the test harness. os._exit(72) - yield None - pid2, status = os.waitpid(pid, 0) - testcase.assertEqual(pid2, pid) - testcase.assertEqual(72 << 8, status) - - -def close_server(server): - server.server_close() - - if hasattr(server, 'active_children'): - # ForkingMixIn: Manually reap all child processes, since server_close() - # calls waitpid() in non-blocking mode using the WNOHANG flag. - for pid in server.active_children.copy(): - try: - os.waitpid(pid, 0) - except ChildProcessError: - pass - server.active_children.clear() + try: + yield None + except: + raise + finally: + pid2, status = os.waitpid(pid, 0) + testcase.assertEqual(pid2, pid) + testcase.assertEqual(72 << 8, status) @unittest.skipUnless(threading, 'Threading required for this test.') @@ -115,6 +105,8 @@ class SocketServerTest(unittest.TestCase): def make_server(self, addr, svrcls, hdlrbase): class MyServer(svrcls): + _block_on_close = True + def handle_error(self, request, client_address): self.close_request(request) raise @@ -156,8 +148,12 @@ class SocketServerTest(unittest.TestCase): if verbose: print("waiting for server") server.shutdown() t.join() - close_server(server) + server.server_close() self.assertEqual(-1, server.socket.fileno()) + if HAVE_FORKING and isinstance(server, socketserver.ForkingMixIn): + # bpo-31151: Check that ForkingMixIn.server_close() waits until + # all children completed + self.assertFalse(server.active_children) if verbose: print("done") def stream_examine(self, proto, addr): @@ -280,7 +276,7 @@ class SocketServerTest(unittest.TestCase): s.shutdown() for t, s in threads: t.join() - close_server(s) + s.server_close() def test_tcpserver_bind_leak(self): # Issue #22435: the server socket wouldn't be closed if bind()/listen() @@ -344,6 +340,8 @@ class ErrorHandlerTest(unittest.TestCase): class BaseErrorTestServer(socketserver.TCPServer): + _block_on_close = True + def __init__(self, exception): self.exception = exception super().__init__((HOST, 0), BadHandler) @@ -352,7 +350,7 @@ class BaseErrorTestServer(socketserver.TCPServer): try: self.handle_request() finally: - close_server(self) + self.server_close() self.wait_done() def handle_error(self, request, client_address): @@ -386,7 +384,7 @@ class ThreadingErrorTestServer(socketserver.ThreadingMixIn, if HAVE_FORKING: class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer): - pass + _block_on_close = True class SocketWriterTest(unittest.TestCase): @@ -398,7 +396,7 @@ class SocketWriterTest(unittest.TestCase): self.server.request_fileno = self.request.fileno() server = socketserver.TCPServer((HOST, 0), Handler) - self.addCleanup(close_server, server) + self.addCleanup(server.server_close) s = socket.socket( server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP) with s: @@ -422,7 +420,7 @@ class SocketWriterTest(unittest.TestCase): self.server.sent2 = self.wfile.write(big_chunk) server = socketserver.TCPServer((HOST, 0), Handler) - self.addCleanup(close_server, server) + self.addCleanup(server.server_close) interrupted = threading.Event() def signal_handler(signum, frame): @@ -498,7 +496,7 @@ class MiscTestCase(unittest.TestCase): s.close() server.handle_request() self.assertEqual(server.shutdown_called, 1) - close_server(server) + server.server_close() if __name__ == "__main__":