Apply patch for 874900: threading module can deadlock after fork

This commit is contained in:
Jesse Noller 2008-07-16 20:03:47 +00:00
parent 1bbf4ea553
commit 5e62ca4fea
3 changed files with 130 additions and 1 deletions

View File

@ -323,6 +323,82 @@ class ThreadTests(unittest.TestCase):
sys.getrefcount(weak_raising_cyclic_object())))
class ThreadJoinOnShutdown(unittest.TestCase):
def _run_and_join(self, script):
script = """if 1:
import sys, os, time, threading
# a thread, which waits for the main program to terminate
def joiningfunc(mainthread):
mainthread.join()
print 'end of thread'
\n""" + script
import subprocess
p = subprocess.Popen([sys.executable, "-c", script], stdout=subprocess.PIPE)
rc = p.wait()
self.assertEqual(p.stdout.read(), "end of main\nend of thread\n")
self.failIf(rc == 2, "interpreter was blocked")
self.failUnless(rc == 0, "Unexpected error")
def test_1_join_on_shutdown(self):
# The usual case: on exit, wait for a non-daemon thread
script = """if 1:
import os
t = threading.Thread(target=joiningfunc,
args=(threading.current_thread(),))
t.start()
time.sleep(0.1)
print 'end of main'
"""
self._run_and_join(script)
def test_2_join_in_forked_process(self):
# Like the test above, but from a forked interpreter
import os
if not hasattr(os, 'fork'):
return
script = """if 1:
childpid = os.fork()
if childpid != 0:
os.waitpid(childpid, 0)
sys.exit(0)
t = threading.Thread(target=joiningfunc,
args=(threading.current_thread(),))
t.start()
print 'end of main'
"""
self._run_and_join(script)
def test_3_join_in_forked_from_thread(self):
# Like the test above, but fork() was called from a worker thread
# In the forked process, the main Thread object must be marked as stopped.
import os
if not hasattr(os, 'fork'):
return
script = """if 1:
main_thread = threading.current_thread()
def worker():
childpid = os.fork()
if childpid != 0:
os.waitpid(childpid, 0)
sys.exit(0)
t = threading.Thread(target=joiningfunc,
args=(main_thread,))
print 'end of main'
t.start()
t.join() # Should not block: main_thread is already stopped
w = threading.Thread(target=worker)
w.start()
"""
self._run_and_join(script)
class ThreadingExceptionTests(unittest.TestCase):
# A RuntimeError should be raised if Thread.start() is called
# multiple times.
@ -363,7 +439,9 @@ class ThreadingExceptionTests(unittest.TestCase):
def test_main():
test.test_support.run_unittest(ThreadTests,
ThreadingExceptionTests)
ThreadJoinOnShutdown,
ThreadingExceptionTests,
)
if __name__ == "__main__":
test_main()

View File

@ -825,6 +825,37 @@ except ImportError:
from _threading_local import local
def _after_fork():
# This function is called by Python/ceval.c:PyEval_ReInitThreads which
# is called from PyOS_AfterFork. Here we cleanup threading module state
# that should not exist after a fork.
# Reset _active_limbo_lock, in case we forked while the lock was held
# by another (non-forked) thread. http://bugs.python.org/issue874900
global _active_limbo_lock
_active_limbo_lock = _allocate_lock()
# fork() only copied the current thread; clear references to others.
new_active = {}
current = current_thread()
with _active_limbo_lock:
for ident, thread in _active.iteritems():
if thread is current:
# There is only one active thread.
new_active[ident] = thread
else:
# All the others are already stopped.
# We don't call _Thread__stop() because it tries to acquire
# thread._Thread__block which could also have been held while
# we forked.
thread._Thread__stopped = True
_limbo.clear()
_active.clear()
_active.update(new_active)
assert len(_active) == 1
# Self-test code
def _test():

View File

@ -274,6 +274,9 @@ PyEval_ReleaseThread(PyThreadState *tstate)
void
PyEval_ReInitThreads(void)
{
PyObject *threading, *result;
PyThreadState *tstate;
if (!interpreter_lock)
return;
/*XXX Can't use PyThread_free_lock here because it does too
@ -283,6 +286,23 @@ PyEval_ReInitThreads(void)
interpreter_lock = PyThread_allocate_lock();
PyThread_acquire_lock(interpreter_lock, 1);
main_thread = PyThread_get_thread_ident();
/* Update the threading module with the new state.
*/
tstate = PyThreadState_GET();
threading = PyMapping_GetItemString(tstate->interp->modules,
"threading");
if (threading == NULL) {
/* threading not imported */
PyErr_Clear();
return;
}
result = PyObject_CallMethod(threading, "_after_fork", NULL);
if (result == NULL)
PyErr_WriteUnraisable(threading);
else
Py_DECREF(result);
Py_DECREF(threading);
}
#endif