diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py index 7c648186f29..e9899c92d90 100644 --- a/Lib/test/test_threading.py +++ b/Lib/test/test_threading.py @@ -220,8 +220,85 @@ class ThreadTests(unittest.TestCase): sys.setcheckinterval(old_interval) +class ThreadJoinOnShutdown(unittest.TestCase): + + def _run_and_join(self, script): + script = """if 1: + import sys, os, time, threading + + # a thread, which waits for the main program to terminate + def joiningfunc(mainthread): + mainthread.join() + print 'end of thread' + \n""" + script + + import subprocess + p = subprocess.Popen([sys.executable, "-c", script], stdout=subprocess.PIPE) + rc = p.wait() + self.assertEqual(p.stdout.read(), "end of main\nend of thread\n") + self.failIf(rc == 2, "interpreter was blocked") + self.failUnless(rc == 0, "Unexpected error") + + def test_1_join_on_shutdown(self): + # The usual case: on exit, wait for a non-daemon thread + script = """if 1: + import os + t = threading.Thread(target=joiningfunc, + args=(threading.currentThread(),)) + t.start() + time.sleep(0.1) + print 'end of main' + """ + self._run_and_join(script) + + + def test_2_join_in_forked_process(self): + # Like the test above, but from a forked interpreter + import os + if not hasattr(os, 'fork'): + return + script = """if 1: + childpid = os.fork() + if childpid != 0: + os.waitpid(childpid, 0) + sys.exit(0) + + t = threading.Thread(target=joiningfunc, + args=(threading.currentThread(),)) + t.start() + print 'end of main' + """ + self._run_and_join(script) + + def test_3_join_in_forked_from_thread(self): + # Like the test above, but fork() was called from a worker thread + # In the forked process, the main Thread object must be marked as stopped. + import os + if not hasattr(os, 'fork'): + return + script = """if 1: + main_thread = threading.currentThread() + def worker(): + childpid = os.fork() + if childpid != 0: + os.waitpid(childpid, 0) + sys.exit(0) + + t = threading.Thread(target=joiningfunc, + args=(main_thread,)) + print 'end of main' + t.start() + t.join() # Should not block: main_thread is already stopped + + w = threading.Thread(target=worker) + w.start() + """ + self._run_and_join(script) + + def test_main(): - test.test_support.run_unittest(ThreadTests) + test.test_support.run_unittest(ThreadTests, + ThreadJoinOnShutdown) if __name__ == "__main__": test_main() diff --git a/Lib/threading.py b/Lib/threading.py index bab3b426414..1ecc06eb339 100644 --- a/Lib/threading.py +++ b/Lib/threading.py @@ -766,6 +766,40 @@ except ImportError: from _threading_local import local +def _after_fork(): + # This function is called by Python/ceval.c:PyEval_ReInitThreads which + # is called from PyOS_AfterFork. Here we cleanup threading module state + # that should not exist after a fork. + + # Reset _active_limbo_lock, in case we forked while the lock was held + # by another (non-forked) thread. http://bugs.python.org/issue874900 + global _active_limbo_lock + _active_limbo_lock = _allocate_lock() + + # fork() only copied the current thread; clear references to others. + new_active = {} + current = currentThread() + _active_limbo_lock.acquire() + try: + for ident, thread in _active.iteritems(): + if thread is current: + # There is only one active thread. + new_active[ident] = thread + else: + # All the others are already stopped. + # We don't call _Thread__stop() because it tries to acquire + # thread._Thread__block which could also have been held while + # we forked. + thread._Thread__stopped = True + + _limbo.clear() + _active.clear() + _active.update(new_active) + assert len(_active) == 1 + finally: + _active_limbo_lock.release() + + # Self-test code def _test(): diff --git a/Misc/NEWS b/Misc/NEWS index 4ba3d579c3a..8864aa9849b 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -144,6 +144,10 @@ Library - Issue #2113: Fix error in subprocess.Popen if the select system call is interrupted by a signal. +- Issue #874900: after an os.fork() call the threading module state is cleaned + up in the child process to prevent deadlock and report proper thread counts + if the new process uses the threading module. + Extension Modules ----------------- diff --git a/Python/ceval.c b/Python/ceval.c index 9bc147b78d8..ddfe3c48ead 100644 --- a/Python/ceval.c +++ b/Python/ceval.c @@ -272,6 +272,9 @@ PyEval_ReleaseThread(PyThreadState *tstate) void PyEval_ReInitThreads(void) { + PyObject *threading, *result; + PyThreadState *tstate; + if (!interpreter_lock) return; /*XXX Can't use PyThread_free_lock here because it does too @@ -281,6 +284,23 @@ PyEval_ReInitThreads(void) interpreter_lock = PyThread_allocate_lock(); PyThread_acquire_lock(interpreter_lock, 1); main_thread = PyThread_get_thread_ident(); + + /* Update the threading module with the new state. + */ + tstate = PyThreadState_GET(); + threading = PyMapping_GetItemString(tstate->interp->modules, + "threading"); + if (threading == NULL) { + /* threading not imported */ + PyErr_Clear(); + return; + } + result = PyObject_CallMethod(threading, "_after_fork", NULL); + if (result == NULL) + PyErr_WriteUnraisable(threading); + else + Py_DECREF(result); + Py_DECREF(threading); } #endif