diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py index 191908239ce..82321af79e1 100644 --- a/Lib/test/support/__init__.py +++ b/Lib/test/support/__init__.py @@ -6,6 +6,7 @@ if __name__ != 'test.support': import collections.abc import contextlib import errno +import faulthandler import fnmatch import functools import gc @@ -96,7 +97,7 @@ __all__ = [ # logging "TestHandler", # threads - "threading_setup", "threading_cleanup", + "threading_setup", "threading_cleanup", "reap_threads", "start_threads", # miscellaneous "check_warnings", "EnvironmentVarGuard", "run_with_locale", "swap_item", "swap_attr", "Matcher", "set_memlimit", "SuppressCrashReport", "sortdict", @@ -1941,6 +1942,42 @@ def reap_children(): except: break +@contextlib.contextmanager +def start_threads(threads, unlock=None): + threads = list(threads) + started = [] + try: + try: + for t in threads: + t.start() + started.append(t) + except: + if verbose: + print("Can't start %d threads, only %d threads started" % + (len(threads), len(started))) + raise + yield + finally: + try: + if unlock: + unlock() + endtime = starttime = time.time() + for timeout in range(1, 16): + endtime += 60 + for t in started: + t.join(max(endtime - time.time(), 0.01)) + started = [t for t in started if t.isAlive()] + if not started: + break + if verbose: + print('Unable to join %d threads during a period of ' + '%d minutes' % (len(started), timeout)) + finally: + started = [t for t in started if t.isAlive()] + if started: + faulthandler.dump_traceback(sys.stdout) + raise AssertionError('Unable to join %d threads' % len(started)) + @contextlib.contextmanager def swap_attr(obj, attr, new_val): """Temporary swap out an attribute with a new object. diff --git a/Lib/test/test_bz2.py b/Lib/test/test_bz2.py index 2fe3f2a1976..bf9887b45a0 100644 --- a/Lib/test/test_bz2.py +++ b/Lib/test/test_bz2.py @@ -507,10 +507,8 @@ class BZ2FileTest(BaseTest): for i in range(5): f.write(data) threads = [threading.Thread(target=comp) for i in range(nthreads)] - for t in threads: - t.start() - for t in threads: - t.join() + with support.start_threads(threads): + pass def testWithoutThreading(self): module = support.import_fresh_module("bz2", blocked=("threading",)) diff --git a/Lib/test/test_capi.py b/Lib/test/test_capi.py index fbd1466c788..dff717b5e05 100644 --- a/Lib/test/test_capi.py +++ b/Lib/test/test_capi.py @@ -284,15 +284,11 @@ class TestPendingCalls(unittest.TestCase): context.lock = threading.Lock() context.event = threading.Event() - for i in range(context.nThreads): - t = threading.Thread(target=self.pendingcalls_thread, args = (context,)) - t.start() - threads.append(t) - - self.pendingcalls_wait(context.l, n, context) - - for t in threads: - t.join() + threads = [threading.Thread(target=self.pendingcalls_thread, + args=(context,)) + for i in range(context.nThreads)] + with support.start_threads(threads): + self.pendingcalls_wait(context.l, n, context) def pendingcalls_thread(self, context): try: diff --git a/Lib/test/test_gc.py b/Lib/test/test_gc.py index c025512790c..2ac1d4bb642 100644 --- a/Lib/test/test_gc.py +++ b/Lib/test/test_gc.py @@ -1,6 +1,6 @@ import unittest from test.support import (verbose, refcount_test, run_unittest, - strip_python_stderr, cpython_only) + strip_python_stderr, cpython_only, start_threads) from test.script_helper import assert_python_ok, make_script, temp_dir import sys @@ -397,19 +397,13 @@ class GCTests(unittest.TestCase): old_switchinterval = sys.getswitchinterval() sys.setswitchinterval(1e-5) try: - exit = False + exit = [] threads = [] for i in range(N_THREADS): t = threading.Thread(target=run_thread) threads.append(t) - try: - for t in threads: - t.start() - finally: + with start_threads(threads, lambda: exit.append(1)): time.sleep(1.0) - exit = True - for t in threads: - t.join() finally: sys.setswitchinterval(old_switchinterval) gc.collect() diff --git a/Lib/test/test_io.py b/Lib/test/test_io.py index 87547795eac..9fb85f21e2c 100644 --- a/Lib/test/test_io.py +++ b/Lib/test/test_io.py @@ -1131,11 +1131,8 @@ class BufferedReaderTest(unittest.TestCase, CommonBufferedTests): errors.append(e) raise threads = [threading.Thread(target=f) for x in range(20)] - for t in threads: - t.start() - time.sleep(0.02) # yield - for t in threads: - t.join() + with support.start_threads(threads): + time.sleep(0.02) # yield self.assertFalse(errors, "the following exceptions were caught: %r" % errors) s = b''.join(results) @@ -1454,11 +1451,8 @@ class BufferedWriterTest(unittest.TestCase, CommonBufferedTests): errors.append(e) raise threads = [threading.Thread(target=f) for x in range(20)] - for t in threads: - t.start() - time.sleep(0.02) # yield - for t in threads: - t.join() + with support.start_threads(threads): + time.sleep(0.02) # yield self.assertFalse(errors, "the following exceptions were caught: %r" % errors) bufio.close() @@ -2752,14 +2746,10 @@ class TextIOWrapperTest(unittest.TestCase): text = "Thread%03d\n" % n event.wait() f.write(text) - threads = [threading.Thread(target=lambda n=x: run(n)) + threads = [threading.Thread(target=run, args=(x,)) for x in range(20)] - for t in threads: - t.start() - time.sleep(0.02) - event.set() - for t in threads: - t.join() + with support.start_threads(threads, event.set): + time.sleep(0.02) with self.open(support.TESTFN) as f: content = f.read() for n in range(20): @@ -3493,7 +3483,7 @@ class SignalsTest(unittest.TestCase): self.assertRaises(ZeroDivisionError, wio.write, large_data) finally: signal.alarm(0) - t.join() + t.join() # We got one byte, get another one and check that it isn't a # repeat of the first one. read_results.append(os.read(r, 1)) diff --git a/Lib/test/test_threaded_import.py b/Lib/test/test_threaded_import.py index 5bf670c8157..9b2d9a6f19b 100644 --- a/Lib/test/test_threaded_import.py +++ b/Lib/test/test_threaded_import.py @@ -14,7 +14,7 @@ import shutil import unittest from test.support import ( verbose, import_module, run_unittest, TESTFN, reap_threads, - forget, unlink, rmtree) + forget, unlink, rmtree, start_threads) threading = import_module('threading') def task(N, done, done_tasks, errors): @@ -116,10 +116,10 @@ class ThreadedImportTests(unittest.TestCase): done_tasks = [] done.clear() t0 = time.monotonic() - for i in range(N): - t = threading.Thread(target=task, - args=(N, done, done_tasks, errors,)) - t.start() + with start_threads(threading.Thread(target=task, + args=(N, done, done_tasks, errors,)) + for i in range(N)): + pass completed = done.wait(10 * 60) dt = time.monotonic() - t0 if verbose: diff --git a/Lib/test/test_threadedtempfile.py b/Lib/test/test_threadedtempfile.py index 2dfd3a08db8..b7420360eae 100644 --- a/Lib/test/test_threadedtempfile.py +++ b/Lib/test/test_threadedtempfile.py @@ -18,7 +18,7 @@ FILES_PER_THREAD = 50 import tempfile -from test.support import threading_setup, threading_cleanup, run_unittest, import_module +from test.support import start_threads, import_module threading = import_module('threading') import unittest import io @@ -46,33 +46,17 @@ class TempFileGreedy(threading.Thread): class ThreadedTempFileTest(unittest.TestCase): def test_main(self): - threads = [] - thread_info = threading_setup() - - for i in range(NUM_THREADS): - t = TempFileGreedy() - threads.append(t) - t.start() - - startEvent.set() - - ok = 0 - errors = [] - for t in threads: - t.join() - ok += t.ok_count - if t.error_count: - errors.append(str(t.name) + str(t.errors.getvalue())) - - threading_cleanup(*thread_info) + threads = [TempFileGreedy() for i in range(NUM_THREADS)] + with start_threads(threads, startEvent.set): + pass + ok = sum(t.ok_count for t in threads) + errors = [str(t.name) + str(t.errors.getvalue()) + for t in threads if t.error_count] msg = "Errors: errors %d ok %d\n%s" % (len(errors), ok, '\n'.join(errors)) self.assertEqual(errors, [], msg) self.assertEqual(ok, NUM_THREADS * FILES_PER_THREAD) -def test_main(): - run_unittest(ThreadedTempFileTest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_threading_local.py b/Lib/test/test_threading_local.py index c886a25d8ab..c7f394cf60b 100644 --- a/Lib/test/test_threading_local.py +++ b/Lib/test/test_threading_local.py @@ -64,14 +64,9 @@ class BaseLocalTest: # Simply check that the variable is correctly set self.assertEqual(local.x, i) - threads= [] - for i in range(10): - t = threading.Thread(target=f, args=(i,)) - t.start() - threads.append(t) - - for t in threads: - t.join() + with support.start_threads(threading.Thread(target=f, args=(i,)) + for i in range(10)): + pass def test_derived_cycle_dealloc(self): # http://bugs.python.org/issue6990 diff --git a/Misc/NEWS b/Misc/NEWS index 76b7ef7a2f6..6b63fdc5530 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -230,6 +230,9 @@ Build Tests ----- +- Issue #23799: Added test.support.start_threads() for running and + cleaning up multiple threads. + - Issue #22390: test.regrtest now emits a warning if temporary files or directories are left after running a test.