mirror of https://github.com/python/cpython
bpo-28699: fix abnormal behaviour of pools in multiprocessing.pool (GH-693)
An exception raised at the very start of an iterable would cause pools to behave abnormally (swallowing the exception or hanging).
This commit is contained in:
parent
ec1f5df46e
commit
794623bdb2
|
@ -118,7 +118,7 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None,
|
|||
try:
|
||||
result = (True, func(*args, **kwds))
|
||||
except Exception as e:
|
||||
if wrap_exception:
|
||||
if wrap_exception and func is not _helper_reraises_exception:
|
||||
e = ExceptionWithTraceback(e, e.__traceback__)
|
||||
result = (False, e)
|
||||
try:
|
||||
|
@ -133,6 +133,10 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None,
|
|||
completed += 1
|
||||
util.debug('worker exiting after %d tasks' % completed)
|
||||
|
||||
# Module-level (hence picklable) helper that simply re-raises the exception
# handed to it.  _guarded_task_generation substitutes this as the "task
# function" when consuming the user's iterable raises, so the error travels
# through the normal worker/result machinery instead of being swallowed or
# hanging the pool (bpo-28699).
def _helper_reraises_exception(ex):
    'Pickle-able helper function for use by _guarded_task_generation.'
    raise ex
|
||||
|
||||
#
|
||||
# Class representing a process pool
|
||||
#
|
||||
|
@ -277,6 +281,17 @@ class Pool(object):
|
|||
return self._map_async(func, iterable, starmapstar, chunksize,
|
||||
callback, error_callback)
|
||||
|
||||
def _guarded_task_generation(self, result_job, func, iterable):
    '''Yield ``(result_job, index, func, (item,), {})`` task tuples for
    imap and imap_unordered, guarding against iterables that raise.

    If consuming *iterable* raises, the exception is converted into one
    final task whose "function" re-raises it in a worker, so the error
    reaches the caller instead of being swallowed or hanging the pool.'''
    task_index = -1
    try:
        for task_index, item in enumerate(iterable):
            yield (result_job, task_index, func, (item,), {})
    except Exception as exc:
        # One extra task slot, placed right after the last good item,
        # delivers the failure through the normal result machinery.
        yield (result_job, task_index + 1,
               _helper_reraises_exception, (exc,), {})
|
||||
|
||||
def imap(self, func, iterable, chunksize=1):
|
||||
'''
|
||||
Equivalent of `map()` -- can be MUCH slower than `Pool.map()`.
|
||||
|
@ -285,15 +300,23 @@ class Pool(object):
|
|||
raise ValueError("Pool not running")
|
||||
if chunksize == 1:
|
||||
result = IMapIterator(self._cache)
|
||||
self._taskqueue.put((((result._job, i, func, (x,), {})
|
||||
for i, x in enumerate(iterable)), result._set_length))
|
||||
self._taskqueue.put(
|
||||
(
|
||||
self._guarded_task_generation(result._job, func, iterable),
|
||||
result._set_length
|
||||
))
|
||||
return result
|
||||
else:
|
||||
assert chunksize > 1
|
||||
task_batches = Pool._get_tasks(func, iterable, chunksize)
|
||||
result = IMapIterator(self._cache)
|
||||
self._taskqueue.put((((result._job, i, mapstar, (x,), {})
|
||||
for i, x in enumerate(task_batches)), result._set_length))
|
||||
self._taskqueue.put(
|
||||
(
|
||||
self._guarded_task_generation(result._job,
|
||||
mapstar,
|
||||
task_batches),
|
||||
result._set_length
|
||||
))
|
||||
return (item for chunk in result for item in chunk)
|
||||
|
||||
def imap_unordered(self, func, iterable, chunksize=1):
|
||||
|
@ -304,15 +327,23 @@ class Pool(object):
|
|||
raise ValueError("Pool not running")
|
||||
if chunksize == 1:
|
||||
result = IMapUnorderedIterator(self._cache)
|
||||
self._taskqueue.put((((result._job, i, func, (x,), {})
|
||||
for i, x in enumerate(iterable)), result._set_length))
|
||||
self._taskqueue.put(
|
||||
(
|
||||
self._guarded_task_generation(result._job, func, iterable),
|
||||
result._set_length
|
||||
))
|
||||
return result
|
||||
else:
|
||||
assert chunksize > 1
|
||||
task_batches = Pool._get_tasks(func, iterable, chunksize)
|
||||
result = IMapUnorderedIterator(self._cache)
|
||||
self._taskqueue.put((((result._job, i, mapstar, (x,), {})
|
||||
for i, x in enumerate(task_batches)), result._set_length))
|
||||
self._taskqueue.put(
|
||||
(
|
||||
self._guarded_task_generation(result._job,
|
||||
mapstar,
|
||||
task_batches),
|
||||
result._set_length
|
||||
))
|
||||
return (item for chunk in result for item in chunk)
|
||||
|
||||
def apply_async(self, func, args=(), kwds={}, callback=None,
|
||||
|
@ -323,7 +354,7 @@ class Pool(object):
|
|||
if self._state != RUN:
|
||||
raise ValueError("Pool not running")
|
||||
result = ApplyResult(self._cache, callback, error_callback)
|
||||
self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
|
||||
self._taskqueue.put(([(result._job, 0, func, args, kwds)], None))
|
||||
return result
|
||||
|
||||
def map_async(self, func, iterable, chunksize=None, callback=None,
|
||||
|
@ -354,8 +385,14 @@ class Pool(object):
|
|||
task_batches = Pool._get_tasks(func, iterable, chunksize)
|
||||
result = MapResult(self._cache, chunksize, len(iterable), callback,
|
||||
error_callback=error_callback)
|
||||
self._taskqueue.put((((result._job, i, mapper, (x,), {})
|
||||
for i, x in enumerate(task_batches)), None))
|
||||
self._taskqueue.put(
|
||||
(
|
||||
self._guarded_task_generation(result._job,
|
||||
mapper,
|
||||
task_batches),
|
||||
None
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
|
@ -377,33 +414,27 @@ class Pool(object):
|
|||
|
||||
for taskseq, set_length in iter(taskqueue.get, None):
|
||||
task = None
|
||||
i = -1
|
||||
try:
|
||||
for i, task in enumerate(taskseq):
|
||||
# iterating taskseq cannot fail
|
||||
for task in taskseq:
|
||||
if thread._state:
|
||||
util.debug('task handler found thread._state != RUN')
|
||||
break
|
||||
try:
|
||||
put(task)
|
||||
except Exception as e:
|
||||
job, ind = task[:2]
|
||||
job, idx = task[:2]
|
||||
try:
|
||||
cache[job]._set(ind, (False, e))
|
||||
cache[job]._set(idx, (False, e))
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
if set_length:
|
||||
util.debug('doing set_length()')
|
||||
set_length(i+1)
|
||||
idx = task[1] if task else -1
|
||||
set_length(idx + 1)
|
||||
continue
|
||||
break
|
||||
except Exception as ex:
|
||||
job, ind = task[:2] if task else (0, 0)
|
||||
if job in cache:
|
||||
cache[job]._set(ind + 1, (False, ex))
|
||||
if set_length:
|
||||
util.debug('doing set_length()')
|
||||
set_length(i+1)
|
||||
finally:
|
||||
task = taskseq = job = None
|
||||
else:
|
||||
|
|
|
@ -1755,6 +1755,8 @@ class CountedObject(object):
|
|||
# Distinctive error type raised deliberately by the test iterables below, so
# the tests can assert that Pool propagates exactly this exception (rather
# than swallowing it or hanging).
class SayWhenError(ValueError): pass
|
||||
|
||||
def exception_throwing_generator(total, when):
|
||||
if when == -1:
|
||||
raise SayWhenError("Somebody said when")
|
||||
for i in range(total):
|
||||
if i == when:
|
||||
raise SayWhenError("Somebody said when")
|
||||
|
@ -1833,6 +1835,32 @@ class _TestPool(BaseTestCase):
|
|||
except multiprocessing.TimeoutError:
|
||||
self.fail("pool.map_async with chunksize stalled on null list")
|
||||
|
||||
def test_map_handle_iterable_exception(self):
    """map() must propagate an exception thrown by the input iterable —
    even on its very first element — instead of swallowing it or hanging
    (bpo-28699)."""
    if self.TYPE == 'manager':
        self.skipTest('test not appropriate for {}'.format(self.TYPE))

    # SayWhenError seen at the very first of the iterable
    with self.assertRaises(SayWhenError):
        self.pool.map(sqr, exception_throwing_generator(1, -1), 1)
    # again, make sure it's reentrant
    with self.assertRaises(SayWhenError):
        self.pool.map(sqr, exception_throwing_generator(1, -1), 1)

    # Exception partway through the iterable, not just at the start.
    with self.assertRaises(SayWhenError):
        self.pool.map(sqr, exception_throwing_generator(10, 3), 1)

    # An iterable that advertises a length (map_async sizes its MapResult
    # with len(iterable)) but whose __next__ raises immediately.
    class SpecialIterable:
        def __iter__(self):
            return self
        def __next__(self):
            raise SayWhenError
        def __len__(self):
            return 1
    with self.assertRaises(SayWhenError):
        self.pool.map(sqr, SpecialIterable(), 1)
    # reentrancy check again
    with self.assertRaises(SayWhenError):
        self.pool.map(sqr, SpecialIterable(), 1)
|
||||
|
||||
def test_async(self):
|
||||
res = self.pool.apply_async(sqr, (7, TIMEOUT1,))
|
||||
get = TimingWrapper(res.get)
|
||||
|
@ -1863,6 +1891,13 @@ class _TestPool(BaseTestCase):
|
|||
if self.TYPE == 'manager':
|
||||
self.skipTest('test not appropriate for {}'.format(self.TYPE))
|
||||
|
||||
# SayWhenError seen at the very first of the iterable
|
||||
it = self.pool.imap(sqr, exception_throwing_generator(1, -1), 1)
|
||||
self.assertRaises(SayWhenError, it.__next__)
|
||||
# again, make sure it's reentrant
|
||||
it = self.pool.imap(sqr, exception_throwing_generator(1, -1), 1)
|
||||
self.assertRaises(SayWhenError, it.__next__)
|
||||
|
||||
it = self.pool.imap(sqr, exception_throwing_generator(10, 3), 1)
|
||||
for i in range(3):
|
||||
self.assertEqual(next(it), i*i)
|
||||
|
@ -1889,6 +1924,17 @@ class _TestPool(BaseTestCase):
|
|||
if self.TYPE == 'manager':
|
||||
self.skipTest('test not appropriate for {}'.format(self.TYPE))
|
||||
|
||||
# SayWhenError seen at the very first of the iterable
|
||||
it = self.pool.imap_unordered(sqr,
|
||||
exception_throwing_generator(1, -1),
|
||||
1)
|
||||
self.assertRaises(SayWhenError, it.__next__)
|
||||
# again, make sure it's reentrant
|
||||
it = self.pool.imap_unordered(sqr,
|
||||
exception_throwing_generator(1, -1),
|
||||
1)
|
||||
self.assertRaises(SayWhenError, it.__next__)
|
||||
|
||||
it = self.pool.imap_unordered(sqr,
|
||||
exception_throwing_generator(10, 3),
|
||||
1)
|
||||
|
@ -1970,7 +2016,7 @@ class _TestPool(BaseTestCase):
|
|||
except Exception as e:
|
||||
exc = e
|
||||
else:
|
||||
raise AssertionError('expected RuntimeError')
|
||||
self.fail('expected RuntimeError')
|
||||
self.assertIs(type(exc), RuntimeError)
|
||||
self.assertEqual(exc.args, (123,))
|
||||
cause = exc.__cause__
|
||||
|
@ -1984,6 +2030,17 @@ class _TestPool(BaseTestCase):
|
|||
sys.excepthook(*sys.exc_info())
|
||||
self.assertIn('raise RuntimeError(123) # some comment',
|
||||
f1.getvalue())
|
||||
# _helper_reraises_exception should not make the error
|
||||
# a remote exception
|
||||
with self.Pool(1) as p:
|
||||
try:
|
||||
p.map(sqr, exception_throwing_generator(1, -1), 1)
|
||||
except Exception as e:
|
||||
exc = e
|
||||
else:
|
||||
self.fail('expected SayWhenError')
|
||||
self.assertIs(type(exc), SayWhenError)
|
||||
self.assertIs(exc.__cause__, None)
|
||||
|
||||
@classmethod
|
||||
def _test_wrapped_exception(cls):
|
||||
|
|
|
@ -291,6 +291,10 @@ Extension Modules
|
|||
Library
|
||||
-------
|
||||
|
||||
- bpo-28699: Fixed a bug in pools in multiprocessing.pool where raising an
  exception at the very start of an iterable could swallow the exception or
  make the program hang.  Patch by Davin Potts and Xiang Zhang.
|
||||
|
||||
- bpo-23890: unittest.TestCase.assertRaises() now manually breaks a reference
|
||||
cycle to not keep objects alive longer than expected.
|
||||
|
||||
|
|
Loading…
Reference in New Issue