diff --git a/Lib/multiprocessing/forkserver.py b/Lib/multiprocessing/forkserver.py index 7a952e2ba69..040b46e66a0 100644 --- a/Lib/multiprocessing/forkserver.py +++ b/Lib/multiprocessing/forkserver.py @@ -34,6 +34,7 @@ class ForkServer(object): def __init__(self): self._forkserver_address = None self._forkserver_alive_fd = None + self._forkserver_pid = None self._inherited_fds = None self._lock = threading.Lock() self._preload_modules = ['__main__'] @@ -90,8 +91,17 @@ class ForkServer(object): ''' with self._lock: semaphore_tracker.ensure_running() - if self._forkserver_alive_fd is not None: - return + if self._forkserver_pid is not None: + # forkserver was launched before, is it still running? + pid, status = os.waitpid(self._forkserver_pid, os.WNOHANG) + if not pid: + # still alive + return + # dead, launch it again + os.close(self._forkserver_alive_fd) + self._forkserver_address = None + self._forkserver_alive_fd = None + self._forkserver_pid = None cmd = ('from multiprocessing.forkserver import main; ' + 'main(%d, %d, %r, **%r)') @@ -127,6 +137,7 @@ class ForkServer(object): os.close(alive_r) self._forkserver_address = address self._forkserver_alive_fd = alive_w + self._forkserver_pid = pid # # @@ -157,11 +168,11 @@ def main(listener_fd, alive_r, preload, main_path=None, sys_path=None): # Dummy signal handler, doesn't do anything pass - # letting SIGINT through avoids KeyboardInterrupt tracebacks - # unblocking SIGCHLD allows the wakeup fd to notify our event loop handlers = { + # unblocking SIGCHLD allows the wakeup fd to notify our event loop signal.SIGCHLD: sigchld_handler, - signal.SIGINT: signal.SIG_DFL, + # protect the process from ^C + signal.SIGINT: signal.SIG_IGN, } old_handlers = {sig: signal.signal(sig, val) for (sig, val) in handlers.items()} diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py index 69c0bd892c8..799146d8a3f 100644 --- a/Lib/test/_test_multiprocessing.py +++ b/Lib/test/_test_multiprocessing.py @@ -603,6 +603,54 @@ class _TestProcess(BaseTestCase): finally: setattr(sys, stream_name, old_stream) + @classmethod + def _sleep_and_set_event(self, evt, delay=0.0): + time.sleep(delay) + evt.set() + + def check_forkserver_death(self, signum): + # bpo-31308: if the forkserver process has died, we should still + # be able to create and run new Process instances (the forkserver + # is implicitly restarted). + if self.TYPE == 'threads': + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + sm = multiprocessing.get_start_method() + if sm != 'forkserver': + # The fork method by design inherits all fds from the parent, + # trying to go against it is a lost battle + self.skipTest('test not appropriate for {}'.format(sm)) + + from multiprocessing.forkserver import _forkserver + _forkserver.ensure_running() + + evt = self.Event() + proc = self.Process(target=self._sleep_and_set_event, args=(evt, 1.0)) + proc.start() + + pid = _forkserver._forkserver_pid + os.kill(pid, signum) + time.sleep(1.0) # give it time to die + + evt2 = self.Event() + proc2 = self.Process(target=self._sleep_and_set_event, args=(evt2,)) + proc2.start() + proc2.join() + self.assertTrue(evt2.is_set()) + self.assertEqual(proc2.exitcode, 0) + + proc.join() + self.assertTrue(evt.is_set()) + self.assertIn(proc.exitcode, (0, 255)) + + def test_forkserver_sigint(self): + # Catchable signal + self.check_forkserver_death(signal.SIGINT) + + def test_forkserver_sigkill(self): + # Uncatchable signal + if os.name != 'nt': + self.check_forkserver_death(signal.SIGKILL) + # # diff --git a/Misc/NEWS.d/next/Library/2017-08-30-17-59-36.bpo-31308.KbexyC.rst b/Misc/NEWS.d/next/Library/2017-08-30-17-59-36.bpo-31308.KbexyC.rst new file mode 100644 index 00000000000..6068b7fd32e --- /dev/null +++ b/Misc/NEWS.d/next/Library/2017-08-30-17-59-36.bpo-31308.KbexyC.rst @@ -0,0 +1,2 @@ +Make multiprocessing's forkserver process immune to Ctrl-C and other user interruptions. +If it crashes, restart it when necessary.