diff --git a/Lib/multiprocessing/pool.py b/Lib/multiprocessing/pool.py index bbe05a550c3..436662acaf5 100644 --- a/Lib/multiprocessing/pool.py +++ b/Lib/multiprocessing/pool.py @@ -592,6 +592,9 @@ class Pool(object): cache[job]._set(i, obj) except KeyError: pass + except Exception: + # Even if we raised we still want to handle callbacks + traceback.print_exc() task = job = obj = None while cache and thread._state != TERMINATE: @@ -609,6 +612,9 @@ class Pool(object): cache[job]._set(i, obj) except KeyError: pass + except Exception: + # Even if we raised we still want to handle callbacks + traceback.print_exc() task = job = obj = None if hasattr(outqueue, '_reader'): @@ -772,13 +778,15 @@ class ApplyResult(object): def _set(self, i, obj): self._success, self._value = obj - if self._callback and self._success: - self._callback(self._value) - if self._error_callback and not self._success: - self._error_callback(self._value) - self._event.set() - del self._cache[self._job] - self._pool = None + try: + if self._callback and self._success: + self._callback(self._value) + if self._error_callback and not self._success: + self._error_callback(self._value) + finally: + self._event.set() + del self._cache[self._job] + self._pool = None __class_getitem__ = classmethod(types.GenericAlias) @@ -809,11 +817,13 @@ class MapResult(ApplyResult): if success and self._success: self._value[i*self._chunksize:(i+1)*self._chunksize] = result if self._number_left == 0: - if self._callback: - self._callback(self._value) - del self._cache[self._job] - self._event.set() - self._pool = None + try: + if self._callback: + self._callback(self._value) + finally: + del self._cache[self._job] + self._event.set() + self._pool = None else: if not success and self._success: # only store first exception @@ -821,11 +831,13 @@ class MapResult(ApplyResult): self._value = result if self._number_left == 0: # only consider the result ready once all jobs are done - if self._error_callback: - self._error_callback(self._value) - del self._cache[self._job] - self._event.set() - self._pool = None + try: + if self._error_callback: + self._error_callback(self._value) + finally: + del self._cache[self._job] + self._event.set() + self._pool = None # # Class whose instances are returned by `Pool.imap()` diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py index fd3b4303f03..dbaf251fd61 100644 --- a/Lib/test/_test_multiprocessing.py +++ b/Lib/test/_test_multiprocessing.py @@ -2741,6 +2741,39 @@ class _TestPoolWorkerErrors(BaseTestCase): p.close() p.join() +class _TestPoolResultHandlerErrors(BaseTestCase): + ALLOWED_TYPES = ('processes', ) + + def test_apply_async_callback_raises_exception(self): + p = multiprocessing.Pool(1) + + def job(): + return 1 + + def callback(value): + raise Exception() + + p.apply_async(job, callback=callback) + + self.assertTrue(p._result_handler.is_alive()) + p.close() + p.join() + + def test_map_async_callback_raises_exception(self): + p = multiprocessing.Pool(1) + + def job(value): + return value + + def callback(value): + raise Exception() + + p.map_async(job, [1], callback=callback) + + self.assertTrue(p._result_handler.is_alive()) + p.close() + p.join() + class _TestPoolWorkerLifetime(BaseTestCase): ALLOWED_TYPES = ('processes', ) @@ -5740,7 +5773,6 @@ def install_tests_in_module_dict(remote_globs, start_method): __module__ = remote_globs['__name__'] local_globs = globals() ALL_TYPES = {'processes', 'threads', 'manager'} - for name, base in local_globs.items(): if not isinstance(base, type): continue