diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py index db6b0d75044..8c1ecce2dda 100644 --- a/Lib/test/test_threading.py +++ b/Lib/test/test_threading.py @@ -323,6 +323,82 @@ class ThreadTests(unittest.TestCase): sys.getrefcount(weak_raising_cyclic_object()))) +class ThreadJoinOnShutdown(unittest.TestCase): + + def _run_and_join(self, script): + script = """if 1: + import sys, os, time, threading + + # a thread, which waits for the main program to terminate + def joiningfunc(mainthread): + mainthread.join() + print 'end of thread' + \n""" + script + + import subprocess + p = subprocess.Popen([sys.executable, "-c", script], stdout=subprocess.PIPE) + rc = p.wait() + self.assertEqual(p.stdout.read(), "end of main\nend of thread\n") + self.failIf(rc == 2, "interpreter was blocked") + self.failUnless(rc == 0, "Unexpected error") + + def test_1_join_on_shutdown(self): + # The usual case: on exit, wait for a non-daemon thread + script = """if 1: + import os + t = threading.Thread(target=joiningfunc, + args=(threading.current_thread(),)) + t.start() + time.sleep(0.1) + print 'end of main' + """ + self._run_and_join(script) + + + def test_2_join_in_forked_process(self): + # Like the test above, but from a forked interpreter + import os + if not hasattr(os, 'fork'): + return + script = """if 1: + childpid = os.fork() + if childpid != 0: + os.waitpid(childpid, 0) + sys.exit(0) + + t = threading.Thread(target=joiningfunc, + args=(threading.current_thread(),)) + t.start() + print 'end of main' + """ + self._run_and_join(script) + + def test_3_join_in_forked_from_thread(self): + # Like the test above, but fork() was called from a worker thread + # In the forked process, the main Thread object must be marked as stopped. + import os + if not hasattr(os, 'fork'): + return + script = """if 1: + main_thread = threading.current_thread() + def worker(): + childpid = os.fork() + if childpid != 0: + os.waitpid(childpid, 0) + sys.exit(0) + + t = threading.Thread(target=joiningfunc, + args=(main_thread,)) + print 'end of main' + t.start() + t.join() # Should not block: main_thread is already stopped + + w = threading.Thread(target=worker) + w.start() + """ + self._run_and_join(script) + + class ThreadingExceptionTests(unittest.TestCase): # A RuntimeError should be raised if Thread.start() is called # multiple times. @@ -363,7 +439,9 @@ class ThreadingExceptionTests(unittest.TestCase): def test_main(): test.test_support.run_unittest(ThreadTests, - ThreadingExceptionTests) + ThreadJoinOnShutdown, + ThreadingExceptionTests, + ) if __name__ == "__main__": test_main() diff --git a/Lib/threading.py b/Lib/threading.py index bfca44c065d..8a1de42bea6 100644 --- a/Lib/threading.py +++ b/Lib/threading.py @@ -825,6 +825,37 @@ except ImportError: from _threading_local import local +def _after_fork(): + # This function is called by Python/ceval.c:PyEval_ReInitThreads which + # is called from PyOS_AfterFork. Here we cleanup threading module state + # that should not exist after a fork. + + # Reset _active_limbo_lock, in case we forked while the lock was held + # by another (non-forked) thread. http://bugs.python.org/issue874900 + global _active_limbo_lock + _active_limbo_lock = _allocate_lock() + + # fork() only copied the current thread; clear references to others. + new_active = {} + current = current_thread() + with _active_limbo_lock: + for ident, thread in _active.iteritems(): + if thread is current: + # There is only one active thread. + new_active[ident] = thread + else: + # All the others are already stopped. + # We don't call _Thread__stop() because it tries to acquire + # thread._Thread__block which could also have been held while + # we forked. + thread._Thread__stopped = True + + _limbo.clear() + _active.clear() + _active.update(new_active) + assert len(_active) == 1 + + # Self-test code def _test(): diff --git a/Python/ceval.c b/Python/ceval.c index a9e37ae1fcb..f61bcd51b2a 100644 --- a/Python/ceval.c +++ b/Python/ceval.c @@ -274,6 +274,9 @@ PyEval_ReleaseThread(PyThreadState *tstate) void PyEval_ReInitThreads(void) { + PyObject *threading, *result; + PyThreadState *tstate; + if (!interpreter_lock) return; /*XXX Can't use PyThread_free_lock here because it does too @@ -283,6 +286,23 @@ PyEval_ReInitThreads(void) interpreter_lock = PyThread_allocate_lock(); PyThread_acquire_lock(interpreter_lock, 1); main_thread = PyThread_get_thread_ident(); + + /* Update the threading module with the new state. + */ + tstate = PyThreadState_GET(); + threading = PyMapping_GetItemString(tstate->interp->modules, + "threading"); + if (threading == NULL) { + /* threading not imported */ + PyErr_Clear(); + return; + } + result = PyObject_CallMethod(threading, "_after_fork", NULL); + if (result == NULL) + PyErr_WriteUnraisable(threading); + else + Py_DECREF(result); + Py_DECREF(threading); } #endif