bpo-31233: socketserver.ThreadingMixIn.server_close() (#3523)

socketserver.ThreadingMixIn now keeps a list of non-daemonic threads
to wait until all these threads complete in server_close().

Reenable test_logging skipped tests.

Fix SocketHandlerTest.tearDown(): close the socket handler before
stopping the server, so the server can join threads.
This commit is contained in:
Victor Stinner 2017-09-13 01:47:22 -07:00 committed by GitHub
parent 97d7e65dfe
commit b8f4163da3
3 changed files with 19 additions and 11 deletions

View File

@ -629,6 +629,9 @@ class ThreadingMixIn:
# Decides how threads will act upon termination of the # Decides how threads will act upon termination of the
# main process # main process
daemon_threads = False daemon_threads = False
# For non-daemonic threads, list of threading.Threading objects
# used by server_close() to wait for all threads completion.
_threads = None
def process_request_thread(self, request, client_address): def process_request_thread(self, request, client_address):
"""Same as in BaseServer but as a thread. """Same as in BaseServer but as a thread.
@ -648,8 +651,20 @@ class ThreadingMixIn:
t = threading.Thread(target = self.process_request_thread, t = threading.Thread(target = self.process_request_thread,
args = (request, client_address)) args = (request, client_address))
t.daemon = self.daemon_threads t.daemon = self.daemon_threads
if not t.daemon:
if self._threads is None:
self._threads = []
self._threads.append(t)
t.start() t.start()
def server_close(self):
super().server_close()
threads = self._threads
self._threads = None
if threads:
for thread in threads:
thread.join()
if hasattr(os, "fork"): if hasattr(os, "fork"):
class ForkingUDPServer(ForkingMixIn, UDPServer): pass class ForkingUDPServer(ForkingMixIn, UDPServer): pass

View File

@ -1465,7 +1465,6 @@ class ConfigFileTest(BaseTest):
self.assertFalse(logger.disabled) self.assertFalse(logger.disabled)
@unittest.skipIf(True, "FIXME: bpo-30830")
class SocketHandlerTest(BaseTest): class SocketHandlerTest(BaseTest):
"""Test for SocketHandler objects.""" """Test for SocketHandler objects."""
@ -1502,11 +1501,11 @@ class SocketHandlerTest(BaseTest):
def tearDown(self): def tearDown(self):
"""Shutdown the TCP server.""" """Shutdown the TCP server."""
try: try:
if self.server:
self.server.stop(2.0)
if self.sock_hdlr: if self.sock_hdlr:
self.root_logger.removeHandler(self.sock_hdlr) self.root_logger.removeHandler(self.sock_hdlr)
self.sock_hdlr.close() self.sock_hdlr.close()
if self.server:
self.server.stop(2.0)
finally: finally:
BaseTest.tearDown(self) BaseTest.tearDown(self)
@ -1563,7 +1562,6 @@ def _get_temp_domain_socket():
os.remove(fn) os.remove(fn)
return fn return fn
@unittest.skipIf(True, "FIXME: bpo-30830")
@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "Unix sockets required") @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "Unix sockets required")
class UnixSocketHandlerTest(SocketHandlerTest): class UnixSocketHandlerTest(SocketHandlerTest):
@ -1581,7 +1579,6 @@ class UnixSocketHandlerTest(SocketHandlerTest):
SocketHandlerTest.tearDown(self) SocketHandlerTest.tearDown(self)
support.unlink(self.address) support.unlink(self.address)
@unittest.skipIf(True, "FIXME: bpo-30830")
class DatagramHandlerTest(BaseTest): class DatagramHandlerTest(BaseTest):
"""Test for DatagramHandler.""" """Test for DatagramHandler."""
@ -1646,7 +1643,6 @@ class DatagramHandlerTest(BaseTest):
self.handled.wait() self.handled.wait()
self.assertEqual(self.log_output, "spam\neggs\n") self.assertEqual(self.log_output, "spam\neggs\n")
@unittest.skipIf(True, "FIXME: bpo-30830")
@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "Unix sockets required") @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "Unix sockets required")
class UnixDatagramHandlerTest(DatagramHandlerTest): class UnixDatagramHandlerTest(DatagramHandlerTest):
@ -1731,7 +1727,6 @@ class SysLogHandlerTest(BaseTest):
self.handled.wait() self.handled.wait()
self.assertEqual(self.log_output, b'<11>h\xc3\xa4m-sp\xc3\xa4m') self.assertEqual(self.log_output, b'<11>h\xc3\xa4m-sp\xc3\xa4m')
@unittest.skipIf(True, "FIXME: bpo-30830")
@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "Unix sockets required") @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "Unix sockets required")
class UnixSysLogHandlerTest(SysLogHandlerTest): class UnixSysLogHandlerTest(SysLogHandlerTest):
@ -1749,7 +1744,6 @@ class UnixSysLogHandlerTest(SysLogHandlerTest):
SysLogHandlerTest.tearDown(self) SysLogHandlerTest.tearDown(self)
support.unlink(self.address) support.unlink(self.address)
@unittest.skipIf(True, "FIXME: bpo-30830")
@unittest.skipUnless(support.IPV6_ENABLED, @unittest.skipUnless(support.IPV6_ENABLED,
'IPv6 support required for this test.') 'IPv6 support required for this test.')
class IPv6SysLogHandlerTest(SysLogHandlerTest): class IPv6SysLogHandlerTest(SysLogHandlerTest):
@ -2872,9 +2866,6 @@ class ConfigDictTest(BaseTest):
logging.warning('Exclamation') logging.warning('Exclamation')
self.assertTrue(output.getvalue().endswith('Exclamation!\n')) self.assertTrue(output.getvalue().endswith('Exclamation!\n'))
# listen() uses ConfigSocketReceiver which is based
# on socketserver.ThreadingTCPServer
@unittest.skipIf(True, "FIXME: bpo-30830")
def setup_via_listener(self, text, verify=None): def setup_via_listener(self, text, verify=None):
text = text.encode("utf-8") text = text.encode("utf-8")
# Ask for a randomly assigned port (by using port 0) # Ask for a randomly assigned port (by using port 0)

View File

@ -0,0 +1,2 @@
socketserver.ThreadingMixIn now keeps a list of non-daemonic threads to wait
until all these threads complete in server_close().