bpo-33540, socketserver: Add _block_on_close for tests (GH-7317)
* Add a private _block_on_close attribute to ForkingMixIn and ThreadingMixIn classes of socketserver. * Use _block_on_close=True in test_socketserver and test_logging
This commit is contained in:
parent
5dbb48aaac
commit
1381bfe977
|
@ -547,8 +547,10 @@ if hasattr(os, "fork"):
|
||||||
timeout = 300
|
timeout = 300
|
||||||
active_children = None
|
active_children = None
|
||||||
max_children = 40
|
max_children = 40
|
||||||
|
# If true, server_close() waits until all child processes complete.
|
||||||
|
_block_on_close = False
|
||||||
|
|
||||||
def collect_children(self):
|
def collect_children(self, *, blocking=False):
|
||||||
"""Internal routine to wait for children that have exited."""
|
"""Internal routine to wait for children that have exited."""
|
||||||
if self.active_children is None:
|
if self.active_children is None:
|
||||||
return
|
return
|
||||||
|
@ -572,7 +574,8 @@ if hasattr(os, "fork"):
|
||||||
# Now reap all defunct children.
|
# Now reap all defunct children.
|
||||||
for pid in self.active_children.copy():
|
for pid in self.active_children.copy():
|
||||||
try:
|
try:
|
||||||
pid, _ = os.waitpid(pid, os.WNOHANG)
|
flags = 0 if blocking else os.WNOHANG
|
||||||
|
pid, _ = os.waitpid(pid, flags)
|
||||||
# if the child hasn't exited yet, pid will be 0 and ignored by
|
# if the child hasn't exited yet, pid will be 0 and ignored by
|
||||||
# discard() below
|
# discard() below
|
||||||
self.active_children.discard(pid)
|
self.active_children.discard(pid)
|
||||||
|
@ -621,6 +624,10 @@ if hasattr(os, "fork"):
|
||||||
finally:
|
finally:
|
||||||
os._exit(status)
|
os._exit(status)
|
||||||
|
|
||||||
|
def server_close(self):
|
||||||
|
super().server_close()
|
||||||
|
self.collect_children(blocking=self._block_on_close)
|
||||||
|
|
||||||
|
|
||||||
class ThreadingMixIn:
|
class ThreadingMixIn:
|
||||||
"""Mix-in class to handle each request in a new thread."""
|
"""Mix-in class to handle each request in a new thread."""
|
||||||
|
@ -628,6 +635,11 @@ class ThreadingMixIn:
|
||||||
# Decides how threads will act upon termination of the
|
# Decides how threads will act upon termination of the
|
||||||
# main process
|
# main process
|
||||||
daemon_threads = False
|
daemon_threads = False
|
||||||
|
# If true, server_close() waits until all non-daemonic threads terminate.
|
||||||
|
_block_on_close = False
|
||||||
|
# For non-daemonic threads, list of threading.Threading objects
|
||||||
|
# used by server_close() to wait for all threads completion.
|
||||||
|
_threads = None
|
||||||
|
|
||||||
def process_request_thread(self, request, client_address):
|
def process_request_thread(self, request, client_address):
|
||||||
"""Same as in BaseServer but as a thread.
|
"""Same as in BaseServer but as a thread.
|
||||||
|
@ -647,8 +659,21 @@ class ThreadingMixIn:
|
||||||
t = threading.Thread(target = self.process_request_thread,
|
t = threading.Thread(target = self.process_request_thread,
|
||||||
args = (request, client_address))
|
args = (request, client_address))
|
||||||
t.daemon = self.daemon_threads
|
t.daemon = self.daemon_threads
|
||||||
|
if not t.daemon and self._block_on_close:
|
||||||
|
if self._threads is None:
|
||||||
|
self._threads = []
|
||||||
|
self._threads.append(t)
|
||||||
t.start()
|
t.start()
|
||||||
|
|
||||||
|
def server_close(self):
|
||||||
|
super().server_close()
|
||||||
|
if self._block_on_close:
|
||||||
|
threads = self._threads
|
||||||
|
self._threads = None
|
||||||
|
if threads:
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
|
||||||
if hasattr(os, "fork"):
|
if hasattr(os, "fork"):
|
||||||
class ForkingUDPServer(ForkingMixIn, UDPServer): pass
|
class ForkingUDPServer(ForkingMixIn, UDPServer): pass
|
||||||
|
|
|
@ -883,6 +883,7 @@ if threading:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
allow_reuse_address = True
|
allow_reuse_address = True
|
||||||
|
_block_on_close = True
|
||||||
|
|
||||||
def __init__(self, addr, handler, poll_interval=0.5,
|
def __init__(self, addr, handler, poll_interval=0.5,
|
||||||
bind_and_activate=True):
|
bind_and_activate=True):
|
||||||
|
@ -915,6 +916,8 @@ if threading:
|
||||||
before calling :meth:`start`, so that the server will
|
before calling :meth:`start`, so that the server will
|
||||||
set up the socket and listen on it.
|
set up the socket and listen on it.
|
||||||
"""
|
"""
|
||||||
|
_block_on_close = True
|
||||||
|
|
||||||
def __init__(self, addr, handler, poll_interval=0.5,
|
def __init__(self, addr, handler, poll_interval=0.5,
|
||||||
bind_and_activate=True):
|
bind_and_activate=True):
|
||||||
class DelegatingUDPRequestHandler(DatagramRequestHandler):
|
class DelegatingUDPRequestHandler(DatagramRequestHandler):
|
||||||
|
@ -1474,11 +1477,11 @@ class SocketHandlerTest(BaseTest):
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
"""Shutdown the TCP server."""
|
"""Shutdown the TCP server."""
|
||||||
try:
|
try:
|
||||||
if self.server:
|
|
||||||
self.server.stop(2.0)
|
|
||||||
if self.sock_hdlr:
|
if self.sock_hdlr:
|
||||||
self.root_logger.removeHandler(self.sock_hdlr)
|
self.root_logger.removeHandler(self.sock_hdlr)
|
||||||
self.sock_hdlr.close()
|
self.sock_hdlr.close()
|
||||||
|
if self.server:
|
||||||
|
self.server.stop(2.0)
|
||||||
finally:
|
finally:
|
||||||
BaseTest.tearDown(self)
|
BaseTest.tearDown(self)
|
||||||
|
|
||||||
|
|
|
@ -48,11 +48,11 @@ def receive(sock, n, timeout=20):
|
||||||
if HAVE_UNIX_SOCKETS and HAVE_FORKING:
|
if HAVE_UNIX_SOCKETS and HAVE_FORKING:
|
||||||
class ForkingUnixStreamServer(socketserver.ForkingMixIn,
|
class ForkingUnixStreamServer(socketserver.ForkingMixIn,
|
||||||
socketserver.UnixStreamServer):
|
socketserver.UnixStreamServer):
|
||||||
pass
|
_block_on_close = True
|
||||||
|
|
||||||
class ForkingUnixDatagramServer(socketserver.ForkingMixIn,
|
class ForkingUnixDatagramServer(socketserver.ForkingMixIn,
|
||||||
socketserver.UnixDatagramServer):
|
socketserver.UnixDatagramServer):
|
||||||
pass
|
_block_on_close = True
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
|
@ -62,24 +62,14 @@ def simple_subprocess(testcase):
|
||||||
if pid == 0:
|
if pid == 0:
|
||||||
# Don't raise an exception; it would be caught by the test harness.
|
# Don't raise an exception; it would be caught by the test harness.
|
||||||
os._exit(72)
|
os._exit(72)
|
||||||
yield None
|
try:
|
||||||
pid2, status = os.waitpid(pid, 0)
|
yield None
|
||||||
testcase.assertEqual(pid2, pid)
|
except:
|
||||||
testcase.assertEqual(72 << 8, status)
|
raise
|
||||||
|
finally:
|
||||||
|
pid2, status = os.waitpid(pid, 0)
|
||||||
def close_server(server):
|
testcase.assertEqual(pid2, pid)
|
||||||
server.server_close()
|
testcase.assertEqual(72 << 8, status)
|
||||||
|
|
||||||
if hasattr(server, 'active_children'):
|
|
||||||
# ForkingMixIn: Manually reap all child processes, since server_close()
|
|
||||||
# calls waitpid() in non-blocking mode using the WNOHANG flag.
|
|
||||||
for pid in server.active_children.copy():
|
|
||||||
try:
|
|
||||||
os.waitpid(pid, 0)
|
|
||||||
except ChildProcessError:
|
|
||||||
pass
|
|
||||||
server.active_children.clear()
|
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipUnless(threading, 'Threading required for this test.')
|
@unittest.skipUnless(threading, 'Threading required for this test.')
|
||||||
|
@ -115,6 +105,8 @@ class SocketServerTest(unittest.TestCase):
|
||||||
|
|
||||||
def make_server(self, addr, svrcls, hdlrbase):
|
def make_server(self, addr, svrcls, hdlrbase):
|
||||||
class MyServer(svrcls):
|
class MyServer(svrcls):
|
||||||
|
_block_on_close = True
|
||||||
|
|
||||||
def handle_error(self, request, client_address):
|
def handle_error(self, request, client_address):
|
||||||
self.close_request(request)
|
self.close_request(request)
|
||||||
raise
|
raise
|
||||||
|
@ -156,8 +148,12 @@ class SocketServerTest(unittest.TestCase):
|
||||||
if verbose: print("waiting for server")
|
if verbose: print("waiting for server")
|
||||||
server.shutdown()
|
server.shutdown()
|
||||||
t.join()
|
t.join()
|
||||||
close_server(server)
|
server.server_close()
|
||||||
self.assertEqual(-1, server.socket.fileno())
|
self.assertEqual(-1, server.socket.fileno())
|
||||||
|
if HAVE_FORKING and isinstance(server, socketserver.ForkingMixIn):
|
||||||
|
# bpo-31151: Check that ForkingMixIn.server_close() waits until
|
||||||
|
# all children completed
|
||||||
|
self.assertFalse(server.active_children)
|
||||||
if verbose: print("done")
|
if verbose: print("done")
|
||||||
|
|
||||||
def stream_examine(self, proto, addr):
|
def stream_examine(self, proto, addr):
|
||||||
|
@ -280,7 +276,7 @@ class SocketServerTest(unittest.TestCase):
|
||||||
s.shutdown()
|
s.shutdown()
|
||||||
for t, s in threads:
|
for t, s in threads:
|
||||||
t.join()
|
t.join()
|
||||||
close_server(s)
|
s.server_close()
|
||||||
|
|
||||||
def test_tcpserver_bind_leak(self):
|
def test_tcpserver_bind_leak(self):
|
||||||
# Issue #22435: the server socket wouldn't be closed if bind()/listen()
|
# Issue #22435: the server socket wouldn't be closed if bind()/listen()
|
||||||
|
@ -344,6 +340,8 @@ class ErrorHandlerTest(unittest.TestCase):
|
||||||
|
|
||||||
|
|
||||||
class BaseErrorTestServer(socketserver.TCPServer):
|
class BaseErrorTestServer(socketserver.TCPServer):
|
||||||
|
_block_on_close = True
|
||||||
|
|
||||||
def __init__(self, exception):
|
def __init__(self, exception):
|
||||||
self.exception = exception
|
self.exception = exception
|
||||||
super().__init__((HOST, 0), BadHandler)
|
super().__init__((HOST, 0), BadHandler)
|
||||||
|
@ -352,7 +350,7 @@ class BaseErrorTestServer(socketserver.TCPServer):
|
||||||
try:
|
try:
|
||||||
self.handle_request()
|
self.handle_request()
|
||||||
finally:
|
finally:
|
||||||
close_server(self)
|
self.server_close()
|
||||||
self.wait_done()
|
self.wait_done()
|
||||||
|
|
||||||
def handle_error(self, request, client_address):
|
def handle_error(self, request, client_address):
|
||||||
|
@ -386,7 +384,7 @@ class ThreadingErrorTestServer(socketserver.ThreadingMixIn,
|
||||||
|
|
||||||
if HAVE_FORKING:
|
if HAVE_FORKING:
|
||||||
class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer):
|
class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer):
|
||||||
pass
|
_block_on_close = True
|
||||||
|
|
||||||
|
|
||||||
class SocketWriterTest(unittest.TestCase):
|
class SocketWriterTest(unittest.TestCase):
|
||||||
|
@ -398,7 +396,7 @@ class SocketWriterTest(unittest.TestCase):
|
||||||
self.server.request_fileno = self.request.fileno()
|
self.server.request_fileno = self.request.fileno()
|
||||||
|
|
||||||
server = socketserver.TCPServer((HOST, 0), Handler)
|
server = socketserver.TCPServer((HOST, 0), Handler)
|
||||||
self.addCleanup(close_server, server)
|
self.addCleanup(server.server_close)
|
||||||
s = socket.socket(
|
s = socket.socket(
|
||||||
server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP)
|
server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP)
|
||||||
with s:
|
with s:
|
||||||
|
@ -422,7 +420,7 @@ class SocketWriterTest(unittest.TestCase):
|
||||||
self.server.sent2 = self.wfile.write(big_chunk)
|
self.server.sent2 = self.wfile.write(big_chunk)
|
||||||
|
|
||||||
server = socketserver.TCPServer((HOST, 0), Handler)
|
server = socketserver.TCPServer((HOST, 0), Handler)
|
||||||
self.addCleanup(close_server, server)
|
self.addCleanup(server.server_close)
|
||||||
interrupted = threading.Event()
|
interrupted = threading.Event()
|
||||||
|
|
||||||
def signal_handler(signum, frame):
|
def signal_handler(signum, frame):
|
||||||
|
@ -498,7 +496,7 @@ class MiscTestCase(unittest.TestCase):
|
||||||
s.close()
|
s.close()
|
||||||
server.handle_request()
|
server.handle_request()
|
||||||
self.assertEqual(server.shutdown_called, 1)
|
self.assertEqual(server.shutdown_called, 1)
|
||||||
close_server(server)
|
server.server_close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue