From 392c159ad6a0a622c5d78b6783039e836829a6d9 Mon Sep 17 00:00:00 2001 From: Jeffrey Yasskin Date: Thu, 28 Feb 2008 18:03:15 +0000 Subject: [PATCH] Prevent SocketServer.ForkingMixIn from waiting on child processes that it didn't create, in most cases. When there are max_children handlers running, it will still wait for any child process, not just handler processes. --- Lib/SocketServer.py | 30 ++++++++++++++++++++--------- Lib/test/test_socketserver.py | 36 +++++++++++++++++++++++++---------- 2 files changed, 47 insertions(+), 19 deletions(-) diff --git a/Lib/SocketServer.py b/Lib/SocketServer.py index 1763c1d5e47..0a194e5de88 100644 --- a/Lib/SocketServer.py +++ b/Lib/SocketServer.py @@ -440,18 +440,30 @@ class ForkingMixIn: def collect_children(self): """Internal routine to wait for children that have exited.""" - while self.active_children: - if len(self.active_children) < self.max_children: - options = os.WNOHANG - else: - # If the maximum number of children are already - # running, block while waiting for a child to exit - options = 0 + if self.active_children is None: return + while len(self.active_children) >= self.max_children: + # XXX: This will wait for any child process, not just ones + # spawned by this library. This could confuse other + # libraries that expect to be able to wait for their own + # children. try: - pid, status = os.waitpid(0, options) + pid, status = os.waitpid(0, options=0) except os.error: pid = None - if not pid: break + if pid not in self.active_children: continue + self.active_children.remove(pid) + + # XXX: This loop runs more system calls than it ought + # to. There should be a way to put the active_children into a + # process group and then use os.waitpid(-pgid) to wait for any + # of that set, but I couldn't find a way to allocate pgids + # that couldn't collide. + for child in self.active_children: + try: + pid, status = os.waitpid(child, os.WNOHANG) + except os.error: + pid = None + if not pid: continue try: self.active_children.remove(pid) except ValueError, e: diff --git a/Lib/test/test_socketserver.py b/Lib/test/test_socketserver.py index 4a70e59ba4f..98a4c1f9afd 100644 --- a/Lib/test/test_socketserver.py +++ b/Lib/test/test_socketserver.py @@ -2,6 +2,7 @@ Test suite for SocketServer.py. """ +import contextlib import errno import imp import os @@ -82,6 +83,18 @@ class ServerThread(threading.Thread): if verbose: print "thread: done" +@contextlib.contextmanager +def simple_subprocess(testcase): + pid = os.fork() + if pid == 0: + # Don't throw an exception; it would be caught by the test harness. + os._exit(72) + yield None + pid2, status = os.waitpid(pid, 0) + testcase.assertEquals(pid2, pid) + testcase.assertEquals(72 << 8, status) + + class SocketServerTest(unittest.TestCase): """Test all socket servers.""" @@ -183,10 +196,11 @@ class SocketServerTest(unittest.TestCase): self.stream_examine) if HAVE_FORKING: - def test_ThreadingTCPServer(self): - self.run_server(SocketServer.ForkingTCPServer, - SocketServer.StreamRequestHandler, - self.stream_examine) + def test_ForkingTCPServer(self): + with simple_subprocess(self): + self.run_server(SocketServer.ForkingTCPServer, + SocketServer.StreamRequestHandler, + self.stream_examine) if HAVE_UNIX_SOCKETS: def test_UnixStreamServer(self): @@ -201,9 +215,10 @@ class SocketServerTest(unittest.TestCase): if HAVE_FORKING: def test_ForkingUnixStreamServer(self): - self.run_server(ForkingUnixStreamServer, - SocketServer.StreamRequestHandler, - self.stream_examine) + with simple_subprocess(self): + self.run_server(ForkingUnixStreamServer, + SocketServer.StreamRequestHandler, + self.stream_examine) def test_UDPServer(self): self.run_server(SocketServer.UDPServer, @@ -217,9 +232,10 @@ class SocketServerTest(unittest.TestCase): if HAVE_FORKING: def test_ForkingUDPServer(self): - self.run_server(SocketServer.ForkingUDPServer, - SocketServer.DatagramRequestHandler, - self.dgram_examine) + with simple_subprocess(self): + self.run_server(SocketServer.ForkingUDPServer, + SocketServer.DatagramRequestHandler, + self.dgram_examine) # Alas, on Linux (at least) recvfrom() doesn't return a meaningful # client address so this cannot work: