bpo-33540, socketserver: Add _block_on_close for tests (GH-7317)

* Add a private _block_on_close attribute to ForkingMixIn and
  ThreadingMixIn classes of socketserver.
* Use _block_on_close=True in test_socketserver and test_logging
This commit is contained in:
Victor Stinner 2018-06-01 16:24:43 +02:00 committed by GitHub
parent 5dbb48aaac
commit 1381bfe977
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 57 additions and 31 deletions

View File

@ -547,8 +547,10 @@ if hasattr(os, "fork"):
timeout = 300 timeout = 300
active_children = None active_children = None
max_children = 40 max_children = 40
# If true, server_close() waits until all child processes complete.
_block_on_close = False
def collect_children(self): def collect_children(self, *, blocking=False):
"""Internal routine to wait for children that have exited.""" """Internal routine to wait for children that have exited."""
if self.active_children is None: if self.active_children is None:
return return
@ -572,7 +574,8 @@ if hasattr(os, "fork"):
# Now reap all defunct children. # Now reap all defunct children.
for pid in self.active_children.copy(): for pid in self.active_children.copy():
try: try:
pid, _ = os.waitpid(pid, os.WNOHANG) flags = 0 if blocking else os.WNOHANG
pid, _ = os.waitpid(pid, flags)
# if the child hasn't exited yet, pid will be 0 and ignored by # if the child hasn't exited yet, pid will be 0 and ignored by
# discard() below # discard() below
self.active_children.discard(pid) self.active_children.discard(pid)
@ -621,6 +624,10 @@ if hasattr(os, "fork"):
finally: finally:
os._exit(status) os._exit(status)
def server_close(self):
super().server_close()
self.collect_children(blocking=self._block_on_close)
class ThreadingMixIn: class ThreadingMixIn:
"""Mix-in class to handle each request in a new thread.""" """Mix-in class to handle each request in a new thread."""
@ -628,6 +635,11 @@ class ThreadingMixIn:
# Decides how threads will act upon termination of the # Decides how threads will act upon termination of the
# main process # main process
daemon_threads = False daemon_threads = False
# If true, server_close() waits until all non-daemonic threads terminate.
_block_on_close = False
# For non-daemonic threads, list of threading.Threading objects
# used by server_close() to wait for all threads completion.
_threads = None
def process_request_thread(self, request, client_address): def process_request_thread(self, request, client_address):
"""Same as in BaseServer but as a thread. """Same as in BaseServer but as a thread.
@ -647,8 +659,21 @@ class ThreadingMixIn:
t = threading.Thread(target = self.process_request_thread, t = threading.Thread(target = self.process_request_thread,
args = (request, client_address)) args = (request, client_address))
t.daemon = self.daemon_threads t.daemon = self.daemon_threads
if not t.daemon and self._block_on_close:
if self._threads is None:
self._threads = []
self._threads.append(t)
t.start() t.start()
def server_close(self):
super().server_close()
if self._block_on_close:
threads = self._threads
self._threads = None
if threads:
for thread in threads:
thread.join()
if hasattr(os, "fork"): if hasattr(os, "fork"):
class ForkingUDPServer(ForkingMixIn, UDPServer): pass class ForkingUDPServer(ForkingMixIn, UDPServer): pass

View File

@ -883,6 +883,7 @@ if threading:
""" """
allow_reuse_address = True allow_reuse_address = True
_block_on_close = True
def __init__(self, addr, handler, poll_interval=0.5, def __init__(self, addr, handler, poll_interval=0.5,
bind_and_activate=True): bind_and_activate=True):
@ -915,6 +916,8 @@ if threading:
before calling :meth:`start`, so that the server will before calling :meth:`start`, so that the server will
set up the socket and listen on it. set up the socket and listen on it.
""" """
_block_on_close = True
def __init__(self, addr, handler, poll_interval=0.5, def __init__(self, addr, handler, poll_interval=0.5,
bind_and_activate=True): bind_and_activate=True):
class DelegatingUDPRequestHandler(DatagramRequestHandler): class DelegatingUDPRequestHandler(DatagramRequestHandler):
@ -1474,11 +1477,11 @@ class SocketHandlerTest(BaseTest):
def tearDown(self): def tearDown(self):
"""Shutdown the TCP server.""" """Shutdown the TCP server."""
try: try:
if self.server:
self.server.stop(2.0)
if self.sock_hdlr: if self.sock_hdlr:
self.root_logger.removeHandler(self.sock_hdlr) self.root_logger.removeHandler(self.sock_hdlr)
self.sock_hdlr.close() self.sock_hdlr.close()
if self.server:
self.server.stop(2.0)
finally: finally:
BaseTest.tearDown(self) BaseTest.tearDown(self)

View File

@ -48,11 +48,11 @@ def receive(sock, n, timeout=20):
if HAVE_UNIX_SOCKETS and HAVE_FORKING: if HAVE_UNIX_SOCKETS and HAVE_FORKING:
class ForkingUnixStreamServer(socketserver.ForkingMixIn, class ForkingUnixStreamServer(socketserver.ForkingMixIn,
socketserver.UnixStreamServer): socketserver.UnixStreamServer):
pass _block_on_close = True
class ForkingUnixDatagramServer(socketserver.ForkingMixIn, class ForkingUnixDatagramServer(socketserver.ForkingMixIn,
socketserver.UnixDatagramServer): socketserver.UnixDatagramServer):
pass _block_on_close = True
@contextlib.contextmanager @contextlib.contextmanager
@ -62,24 +62,14 @@ def simple_subprocess(testcase):
if pid == 0: if pid == 0:
# Don't raise an exception; it would be caught by the test harness. # Don't raise an exception; it would be caught by the test harness.
os._exit(72) os._exit(72)
yield None try:
pid2, status = os.waitpid(pid, 0) yield None
testcase.assertEqual(pid2, pid) except:
testcase.assertEqual(72 << 8, status) raise
finally:
pid2, status = os.waitpid(pid, 0)
def close_server(server): testcase.assertEqual(pid2, pid)
server.server_close() testcase.assertEqual(72 << 8, status)
if hasattr(server, 'active_children'):
# ForkingMixIn: Manually reap all child processes, since server_close()
# calls waitpid() in non-blocking mode using the WNOHANG flag.
for pid in server.active_children.copy():
try:
os.waitpid(pid, 0)
except ChildProcessError:
pass
server.active_children.clear()
@unittest.skipUnless(threading, 'Threading required for this test.') @unittest.skipUnless(threading, 'Threading required for this test.')
@ -115,6 +105,8 @@ class SocketServerTest(unittest.TestCase):
def make_server(self, addr, svrcls, hdlrbase): def make_server(self, addr, svrcls, hdlrbase):
class MyServer(svrcls): class MyServer(svrcls):
_block_on_close = True
def handle_error(self, request, client_address): def handle_error(self, request, client_address):
self.close_request(request) self.close_request(request)
raise raise
@ -156,8 +148,12 @@ class SocketServerTest(unittest.TestCase):
if verbose: print("waiting for server") if verbose: print("waiting for server")
server.shutdown() server.shutdown()
t.join() t.join()
close_server(server) server.server_close()
self.assertEqual(-1, server.socket.fileno()) self.assertEqual(-1, server.socket.fileno())
if HAVE_FORKING and isinstance(server, socketserver.ForkingMixIn):
# bpo-31151: Check that ForkingMixIn.server_close() waits until
# all children completed
self.assertFalse(server.active_children)
if verbose: print("done") if verbose: print("done")
def stream_examine(self, proto, addr): def stream_examine(self, proto, addr):
@ -280,7 +276,7 @@ class SocketServerTest(unittest.TestCase):
s.shutdown() s.shutdown()
for t, s in threads: for t, s in threads:
t.join() t.join()
close_server(s) s.server_close()
def test_tcpserver_bind_leak(self): def test_tcpserver_bind_leak(self):
# Issue #22435: the server socket wouldn't be closed if bind()/listen() # Issue #22435: the server socket wouldn't be closed if bind()/listen()
@ -344,6 +340,8 @@ class ErrorHandlerTest(unittest.TestCase):
class BaseErrorTestServer(socketserver.TCPServer): class BaseErrorTestServer(socketserver.TCPServer):
_block_on_close = True
def __init__(self, exception): def __init__(self, exception):
self.exception = exception self.exception = exception
super().__init__((HOST, 0), BadHandler) super().__init__((HOST, 0), BadHandler)
@ -352,7 +350,7 @@ class BaseErrorTestServer(socketserver.TCPServer):
try: try:
self.handle_request() self.handle_request()
finally: finally:
close_server(self) self.server_close()
self.wait_done() self.wait_done()
def handle_error(self, request, client_address): def handle_error(self, request, client_address):
@ -386,7 +384,7 @@ class ThreadingErrorTestServer(socketserver.ThreadingMixIn,
if HAVE_FORKING: if HAVE_FORKING:
class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer): class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer):
pass _block_on_close = True
class SocketWriterTest(unittest.TestCase): class SocketWriterTest(unittest.TestCase):
@ -398,7 +396,7 @@ class SocketWriterTest(unittest.TestCase):
self.server.request_fileno = self.request.fileno() self.server.request_fileno = self.request.fileno()
server = socketserver.TCPServer((HOST, 0), Handler) server = socketserver.TCPServer((HOST, 0), Handler)
self.addCleanup(close_server, server) self.addCleanup(server.server_close)
s = socket.socket( s = socket.socket(
server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP) server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP)
with s: with s:
@ -422,7 +420,7 @@ class SocketWriterTest(unittest.TestCase):
self.server.sent2 = self.wfile.write(big_chunk) self.server.sent2 = self.wfile.write(big_chunk)
server = socketserver.TCPServer((HOST, 0), Handler) server = socketserver.TCPServer((HOST, 0), Handler)
self.addCleanup(close_server, server) self.addCleanup(server.server_close)
interrupted = threading.Event() interrupted = threading.Event()
def signal_handler(signum, frame): def signal_handler(signum, frame):
@ -498,7 +496,7 @@ class MiscTestCase(unittest.TestCase):
s.close() s.close()
server.handle_request() server.handle_request()
self.assertEqual(server.shutdown_called, 1) self.assertEqual(server.shutdown_called, 1)
close_server(server) server.server_close()
if __name__ == "__main__": if __name__ == "__main__":