bpo-31234: Add support.join_thread() helper (#3587)

join_thread() joins a thread but raises an AssertionError if the
thread is still alive after timeout seconds.
This commit is contained in:
Victor Stinner 2017-09-14 14:40:56 -07:00 committed by GitHub
parent 167cbde50a
commit b9b69003d9
9 changed files with 50 additions and 65 deletions

View File

@ -21,6 +21,7 @@ import operator
import weakref import weakref
import test.support import test.support
import test.support.script_helper import test.support.script_helper
from test import support
# Skip tests if _multiprocessing wasn't built. # Skip tests if _multiprocessing wasn't built.
@ -72,6 +73,12 @@ def close_queue(queue):
queue.join_thread() queue.join_thread()
def join_process(process, timeout):
# Since multiprocessing.Process has the same API than threading.Thread
# (join() and is_alive(), the support function can be reused
support.join_thread(process, timeout)
# #
# Constants # Constants
# #
@ -477,7 +484,7 @@ class _TestProcess(BaseTestCase):
for p in procs: for p in procs:
p.start() p.start()
for p in procs: for p in procs:
p.join(timeout=10) join_process(p, timeout=10)
for p in procs: for p in procs:
self.assertEqual(p.exitcode, 0) self.assertEqual(p.exitcode, 0)
@ -489,7 +496,7 @@ class _TestProcess(BaseTestCase):
for p in procs: for p in procs:
p.terminate() p.terminate()
for p in procs: for p in procs:
p.join(timeout=10) join_process(p, timeout=10)
if os.name != 'nt': if os.name != 'nt':
for p in procs: for p in procs:
self.assertEqual(p.exitcode, -signal.SIGTERM) self.assertEqual(p.exitcode, -signal.SIGTERM)
@ -652,7 +659,7 @@ class _TestSubclassingProcess(BaseTestCase):
p = self.Process(target=self._test_sys_exit, args=(reason, testfn)) p = self.Process(target=self._test_sys_exit, args=(reason, testfn))
p.daemon = True p.daemon = True
p.start() p.start()
p.join(5) join_process(p, timeout=5)
self.assertEqual(p.exitcode, 1) self.assertEqual(p.exitcode, 1)
with open(testfn, 'r') as f: with open(testfn, 'r') as f:
@ -665,7 +672,7 @@ class _TestSubclassingProcess(BaseTestCase):
p = self.Process(target=sys.exit, args=(reason,)) p = self.Process(target=sys.exit, args=(reason,))
p.daemon = True p.daemon = True
p.start() p.start()
p.join(5) join_process(p, timeout=5)
self.assertEqual(p.exitcode, reason) self.assertEqual(p.exitcode, reason)
# #
@ -1254,8 +1261,7 @@ class _TestCondition(BaseTestCase):
state.value += 1 state.value += 1
cond.notify() cond.notify()
p.join(5) join_process(p, timeout=5)
self.assertFalse(p.is_alive())
self.assertEqual(p.exitcode, 0) self.assertEqual(p.exitcode, 0)
@classmethod @classmethod
@ -1291,7 +1297,7 @@ class _TestCondition(BaseTestCase):
state.value += 1 state.value += 1
cond.notify() cond.notify()
p.join(5) join_process(p, timeout=5)
self.assertTrue(success.value) self.assertTrue(success.value)
@classmethod @classmethod
@ -4005,7 +4011,7 @@ class TestTimeouts(unittest.TestCase):
self.assertEqual(conn.recv(), 456) self.assertEqual(conn.recv(), 456)
conn.close() conn.close()
l.close() l.close()
p.join(10) join_process(p, timeout=10)
finally: finally:
socket.setdefaulttimeout(old_timeout) socket.setdefaulttimeout(old_timeout)
@ -4041,7 +4047,7 @@ class TestForkAwareThreadLock(unittest.TestCase):
p = multiprocessing.Process(target=cls.child, args=(n-1, conn)) p = multiprocessing.Process(target=cls.child, args=(n-1, conn))
p.start() p.start()
conn.close() conn.close()
p.join(timeout=5) join_process(p, timeout=5)
else: else:
conn.send(len(util._afterfork_registry)) conn.send(len(util._afterfork_registry))
conn.close() conn.close()
@ -4054,7 +4060,7 @@ class TestForkAwareThreadLock(unittest.TestCase):
p.start() p.start()
w.close() w.close()
new_size = r.recv() new_size = r.recv()
p.join(timeout=5) join_process(p, timeout=5)
self.assertLessEqual(new_size, old_size) self.assertLessEqual(new_size, old_size)
# #
@ -4109,7 +4115,7 @@ class TestCloseFds(unittest.TestCase):
p.start() p.start()
writer.close() writer.close()
e = reader.recv() e = reader.recv()
p.join(timeout=5) join_process(p, timeout=5)
finally: finally:
self.close(fd) self.close(fd)
writer.close() writer.close()

View File

@ -2107,6 +2107,16 @@ def wait_threads_exit(timeout=60.0):
gc_collect() gc_collect()
def join_thread(thread, timeout=30.0):
"""Join a thread. Raise an AssertionError if the thread is still alive
after timeout seconds.
"""
thread.join(timeout)
if thread.is_alive():
msg = f"failed to join the thread in {timeout:.1f} seconds"
raise AssertionError(msg)
def reap_children(): def reap_children():
"""Use this function at the end of test_main() whenever sub-processes """Use this function at the end of test_main() whenever sub-processes
are started. This will help ensure that no extra children (zombies) are started. This will help ensure that no extra children (zombies)

View File

@ -123,9 +123,7 @@ class TestAsynchat(unittest.TestCase):
c.push(b"I'm not dead yet!" + term) c.push(b"I'm not dead yet!" + term)
c.push(SERVER_QUIT) c.push(SERVER_QUIT)
asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01) asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01)
s.join(timeout=TIMEOUT) support.join_thread(s, timeout=TIMEOUT)
if s.is_alive():
self.fail("join() timed out")
self.assertEqual(c.contents, [b"hello world", b"I'm not dead yet!"]) self.assertEqual(c.contents, [b"hello world", b"I'm not dead yet!"])
@ -156,9 +154,7 @@ class TestAsynchat(unittest.TestCase):
c.push(data) c.push(data)
c.push(SERVER_QUIT) c.push(SERVER_QUIT)
asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01) asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01)
s.join(timeout=TIMEOUT) support.join_thread(s, timeout=TIMEOUT)
if s.is_alive():
self.fail("join() timed out")
self.assertEqual(c.contents, [data[:termlen]]) self.assertEqual(c.contents, [data[:termlen]])
@ -178,9 +174,7 @@ class TestAsynchat(unittest.TestCase):
c.push(data) c.push(data)
c.push(SERVER_QUIT) c.push(SERVER_QUIT)
asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01) asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01)
s.join(timeout=TIMEOUT) support.join_thread(s, timeout=TIMEOUT)
if s.is_alive():
self.fail("join() timed out")
self.assertEqual(c.contents, []) self.assertEqual(c.contents, [])
self.assertEqual(c.buffer, data) self.assertEqual(c.buffer, data)
@ -192,9 +186,7 @@ class TestAsynchat(unittest.TestCase):
p = asynchat.simple_producer(data+SERVER_QUIT, buffer_size=8) p = asynchat.simple_producer(data+SERVER_QUIT, buffer_size=8)
c.push_with_producer(p) c.push_with_producer(p)
asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01) asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01)
s.join(timeout=TIMEOUT) support.join_thread(s, timeout=TIMEOUT)
if s.is_alive():
self.fail("join() timed out")
self.assertEqual(c.contents, [b"hello world", b"I'm not dead yet!"]) self.assertEqual(c.contents, [b"hello world", b"I'm not dead yet!"])
@ -204,9 +196,7 @@ class TestAsynchat(unittest.TestCase):
data = b"hello world\nI'm not dead yet!\n" data = b"hello world\nI'm not dead yet!\n"
c.push_with_producer(data+SERVER_QUIT) c.push_with_producer(data+SERVER_QUIT)
asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01) asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01)
s.join(timeout=TIMEOUT) support.join_thread(s, timeout=TIMEOUT)
if s.is_alive():
self.fail("join() timed out")
self.assertEqual(c.contents, [b"hello world", b"I'm not dead yet!"]) self.assertEqual(c.contents, [b"hello world", b"I'm not dead yet!"])
@ -217,9 +207,7 @@ class TestAsynchat(unittest.TestCase):
c.push(b"hello world\n\nI'm not dead yet!\n") c.push(b"hello world\n\nI'm not dead yet!\n")
c.push(SERVER_QUIT) c.push(SERVER_QUIT)
asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01) asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01)
s.join(timeout=TIMEOUT) support.join_thread(s, timeout=TIMEOUT)
if s.is_alive():
self.fail("join() timed out")
self.assertEqual(c.contents, self.assertEqual(c.contents,
[b"hello world", b"", b"I'm not dead yet!"]) [b"hello world", b"", b"I'm not dead yet!"])
@ -238,9 +226,7 @@ class TestAsynchat(unittest.TestCase):
# where the server echoes all of its data before we can check that it # where the server echoes all of its data before we can check that it
# got any down below. # got any down below.
s.start_resend_event.set() s.start_resend_event.set()
s.join(timeout=TIMEOUT) support.join_thread(s, timeout=TIMEOUT)
if s.is_alive():
self.fail("join() timed out")
self.assertEqual(c.contents, []) self.assertEqual(c.contents, [])
# the server might have been able to send a byte or two back, but this # the server might have been able to send a byte or two back, but this
@ -261,7 +247,7 @@ class TestAsynchat(unittest.TestCase):
self.assertRaises(TypeError, c.push, 'unicode') self.assertRaises(TypeError, c.push, 'unicode')
c.push(SERVER_QUIT) c.push(SERVER_QUIT)
asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01) asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01)
s.join(timeout=TIMEOUT) support.join_thread(s, timeout=TIMEOUT)
self.assertEqual(c.contents, [b'bytes', b'bytes', b'bytes']) self.assertEqual(c.contents, [b'bytes', b'bytes', b'bytes'])

View File

@ -808,7 +808,7 @@ class EventLoopTestsMixin:
proto.transport.close() proto.transport.close()
lsock.close() lsock.close()
thread.join(1) support.join_thread(thread, timeout=1)
self.assertFalse(thread.is_alive()) self.assertFalse(thread.is_alive())
self.assertEqual(proto.state, 'CLOSED') self.assertEqual(proto.state, 'CLOSED')
self.assertEqual(proto.nbytes, len(message)) self.assertEqual(proto.nbytes, len(message))

View File

@ -360,9 +360,7 @@ class DispatcherWithSendTests(unittest.TestCase):
self.assertEqual(cap.getvalue(), data*2) self.assertEqual(cap.getvalue(), data*2)
finally: finally:
t.join(timeout=TIMEOUT) support.join_thread(t, timeout=TIMEOUT)
if t.is_alive():
self.fail("join() timed out")
@unittest.skipUnless(hasattr(asyncore, 'file_wrapper'), @unittest.skipUnless(hasattr(asyncore, 'file_wrapper'),
@ -794,9 +792,7 @@ class BaseTestAPI:
except OSError: except OSError:
pass pass
finally: finally:
t.join(timeout=TIMEOUT) support.join_thread(t, timeout=TIMEOUT)
if t.is_alive():
self.fail("join() timed out")
class TestAPI_UseIPv4Sockets(BaseTestAPI): class TestAPI_UseIPv4Sockets(BaseTestAPI):
family = socket.AF_INET family = socket.AF_INET

View File

@ -220,7 +220,9 @@ class NewIMAPTestsMixin():
# cleanup the server # cleanup the server
self.server.shutdown() self.server.shutdown()
self.server.server_close() self.server.server_close()
self.thread.join(3.0) support.join_thread(self.thread, 3.0)
# Explicitly clear the attribute to prevent dangling thread
self.thread = None
def test_EOF_without_complete_welcome_message(self): def test_EOF_without_complete_welcome_message(self):
# http://bugs.python.org/issue5949 # http://bugs.python.org/issue5949

View File

@ -791,13 +791,10 @@ class TestSMTPServer(smtpd.SMTPServer):
to terminate. to terminate.
""" """
self.close() self.close()
self._thread.join(timeout) support.join_thread(self._thread, timeout)
self._thread = None
asyncore.close_all(map=self._map, ignore_all=True) asyncore.close_all(map=self._map, ignore_all=True)
alive = self._thread.is_alive()
self._thread = None
if alive:
self.fail("join() timed out")
class ControlMixin(object): class ControlMixin(object):
""" """
@ -847,11 +844,8 @@ class ControlMixin(object):
""" """
self.shutdown() self.shutdown()
if self._thread is not None: if self._thread is not None:
self._thread.join(timeout) support.join_thread(self._thread, timeout)
alive = self._thread.is_alive()
self._thread = None self._thread = None
if alive:
self.fail("join() timed out")
self.server_close() self.server_close()
self.ready.clear() self.ready.clear()
@ -2892,9 +2886,7 @@ class ConfigDictTest(BaseTest):
finally: finally:
t.ready.wait(2.0) t.ready.wait(2.0)
logging.config.stopListening() logging.config.stopListening()
t.join(2.0) support.join_thread(t, 2.0)
if t.is_alive():
self.fail("join() timed out")
def test_listen_config_10_ok(self): def test_listen_config_10_ok(self):
with support.captured_stdout() as output: with support.captured_stdout() as output:

View File

@ -58,10 +58,7 @@ class BlockingTestMixin:
block_func) block_func)
return self.result return self.result
finally: finally:
thread.join(10) # make sure the thread terminates support.join_thread(thread, 10) # make sure the thread terminates
if thread.is_alive():
self.fail("trigger function '%r' appeared to not return" %
trigger_func)
# Call this instead if block_func is supposed to raise an exception. # Call this instead if block_func is supposed to raise an exception.
def do_exceptional_blocking_test(self,block_func, block_args, trigger_func, def do_exceptional_blocking_test(self,block_func, block_args, trigger_func,
@ -77,10 +74,7 @@ class BlockingTestMixin:
self.fail("expected exception of kind %r" % self.fail("expected exception of kind %r" %
expected_exception_class) expected_exception_class)
finally: finally:
thread.join(10) # make sure the thread terminates support.join_thread(thread, 10) # make sure the thread terminates
if thread.is_alive():
self.fail("trigger function '%r' appeared to not return" %
trigger_func)
if not thread.startedEvent.is_set(): if not thread.startedEvent.is_set():
self.fail("trigger thread ended but event never set") self.fail("trigger thread ended but event never set")

View File

@ -3,6 +3,7 @@ import sched
import threading import threading
import time import time
import unittest import unittest
from test import support
TIMEOUT = 10 TIMEOUT = 10
@ -81,8 +82,7 @@ class TestCase(unittest.TestCase):
self.assertEqual(q.get(timeout=TIMEOUT), 5) self.assertEqual(q.get(timeout=TIMEOUT), 5)
self.assertTrue(q.empty()) self.assertTrue(q.empty())
timer.advance(1000) timer.advance(1000)
t.join(timeout=TIMEOUT) support.join_thread(t, timeout=TIMEOUT)
self.assertFalse(t.is_alive())
self.assertTrue(q.empty()) self.assertTrue(q.empty())
self.assertEqual(timer.time(), 5) self.assertEqual(timer.time(), 5)
@ -137,8 +137,7 @@ class TestCase(unittest.TestCase):
self.assertEqual(q.get(timeout=TIMEOUT), 4) self.assertEqual(q.get(timeout=TIMEOUT), 4)
self.assertTrue(q.empty()) self.assertTrue(q.empty())
timer.advance(1000) timer.advance(1000)
t.join(timeout=TIMEOUT) support.join_thread(t, timeout=TIMEOUT)
self.assertFalse(t.is_alive())
self.assertTrue(q.empty()) self.assertTrue(q.empty())
self.assertEqual(timer.time(), 4) self.assertEqual(timer.time(), 4)