bpo-28699: fix abnormal behaviour of pools in multiprocessing.pool (GH-693)

An exception raised at the very start of an iterable would cause pools to behave abnormally
(swallowing the exception or hanging).
This commit is contained in:
Xiang Zhang 2017-03-29 11:58:54 +08:00 committed by GitHub
parent ec1f5df46e
commit 794623bdb2
3 changed files with 117 additions and 25 deletions

View File

@ -118,7 +118,7 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None,
try:
result = (True, func(*args, **kwds))
except Exception as e:
if wrap_exception:
if wrap_exception and func is not _helper_reraises_exception:
e = ExceptionWithTraceback(e, e.__traceback__)
result = (False, e)
try:
@ -133,6 +133,10 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None,
completed += 1
util.debug('worker exiting after %d tasks' % completed)
def _helper_reraises_exception(ex):
'Pickle-able helper function for use by _guarded_task_generation.'
raise ex
#
# Class representing a process pool
#
@ -277,6 +281,17 @@ class Pool(object):
return self._map_async(func, iterable, starmapstar, chunksize,
callback, error_callback)
def _guarded_task_generation(self, result_job, func, iterable):
'''Provides a generator of tasks for imap and imap_unordered with
appropriate handling for iterables which throw exceptions during
iteration.'''
try:
i = -1
for i, x in enumerate(iterable):
yield (result_job, i, func, (x,), {})
except Exception as e:
yield (result_job, i+1, _helper_reraises_exception, (e,), {})
def imap(self, func, iterable, chunksize=1):
'''
Equivalent of `map()` -- can be MUCH slower than `Pool.map()`.
@ -285,15 +300,23 @@ class Pool(object):
raise ValueError("Pool not running")
if chunksize == 1:
result = IMapIterator(self._cache)
self._taskqueue.put((((result._job, i, func, (x,), {})
for i, x in enumerate(iterable)), result._set_length))
self._taskqueue.put(
(
self._guarded_task_generation(result._job, func, iterable),
result._set_length
))
return result
else:
assert chunksize > 1
task_batches = Pool._get_tasks(func, iterable, chunksize)
result = IMapIterator(self._cache)
self._taskqueue.put((((result._job, i, mapstar, (x,), {})
for i, x in enumerate(task_batches)), result._set_length))
self._taskqueue.put(
(
self._guarded_task_generation(result._job,
mapstar,
task_batches),
result._set_length
))
return (item for chunk in result for item in chunk)
def imap_unordered(self, func, iterable, chunksize=1):
@ -304,15 +327,23 @@ class Pool(object):
raise ValueError("Pool not running")
if chunksize == 1:
result = IMapUnorderedIterator(self._cache)
self._taskqueue.put((((result._job, i, func, (x,), {})
for i, x in enumerate(iterable)), result._set_length))
self._taskqueue.put(
(
self._guarded_task_generation(result._job, func, iterable),
result._set_length
))
return result
else:
assert chunksize > 1
task_batches = Pool._get_tasks(func, iterable, chunksize)
result = IMapUnorderedIterator(self._cache)
self._taskqueue.put((((result._job, i, mapstar, (x,), {})
for i, x in enumerate(task_batches)), result._set_length))
self._taskqueue.put(
(
self._guarded_task_generation(result._job,
mapstar,
task_batches),
result._set_length
))
return (item for chunk in result for item in chunk)
def apply_async(self, func, args=(), kwds={}, callback=None,
@ -323,7 +354,7 @@ class Pool(object):
if self._state != RUN:
raise ValueError("Pool not running")
result = ApplyResult(self._cache, callback, error_callback)
self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
self._taskqueue.put(([(result._job, 0, func, args, kwds)], None))
return result
def map_async(self, func, iterable, chunksize=None, callback=None,
@ -354,8 +385,14 @@ class Pool(object):
task_batches = Pool._get_tasks(func, iterable, chunksize)
result = MapResult(self._cache, chunksize, len(iterable), callback,
error_callback=error_callback)
self._taskqueue.put((((result._job, i, mapper, (x,), {})
for i, x in enumerate(task_batches)), None))
self._taskqueue.put(
(
self._guarded_task_generation(result._job,
mapper,
task_batches),
None
)
)
return result
@staticmethod
@ -377,33 +414,27 @@ class Pool(object):
for taskseq, set_length in iter(taskqueue.get, None):
task = None
i = -1
try:
for i, task in enumerate(taskseq):
# iterating taskseq cannot fail
for task in taskseq:
if thread._state:
util.debug('task handler found thread._state != RUN')
break
try:
put(task)
except Exception as e:
job, ind = task[:2]
job, idx = task[:2]
try:
cache[job]._set(ind, (False, e))
cache[job]._set(idx, (False, e))
except KeyError:
pass
else:
if set_length:
util.debug('doing set_length()')
set_length(i+1)
idx = task[1] if task else -1
set_length(idx + 1)
continue
break
except Exception as ex:
job, ind = task[:2] if task else (0, 0)
if job in cache:
cache[job]._set(ind + 1, (False, ex))
if set_length:
util.debug('doing set_length()')
set_length(i+1)
finally:
task = taskseq = job = None
else:

View File

@ -1755,6 +1755,8 @@ class CountedObject(object):
class SayWhenError(ValueError):
    """Error raised on purpose by the test iterables in this file."""
def exception_throwing_generator(total, when):
if when == -1:
raise SayWhenError("Somebody said when")
for i in range(total):
if i == when:
raise SayWhenError("Somebody said when")
@ -1833,6 +1835,32 @@ class _TestPool(BaseTestCase):
except multiprocessing.TimeoutError:
self.fail("pool.map_async with chunksize stalled on null list")
def test_map_handle_iterable_exception(self):
    """pool.map() must propagate an exception raised by its input iterable."""
    if self.TYPE == 'manager':
        self.skipTest('test not appropriate for {}'.format(self.TYPE))

    class SpecialIterable:
        # An iterator whose very first __next__() call fails.  It reports
        # a length of 1 (presumably so map() can size it without
        # consuming it first -- confirm against Pool.map_async).
        def __iter__(self):
            return self

        def __next__(self):
            raise SayWhenError

        def __len__(self):
            return 1

    # The generator raises SayWhenError before producing any item.
    # Run the call twice to make sure the pool is still usable afterwards.
    for _ in range(2):
        with self.assertRaises(SayWhenError):
            self.pool.map(sqr, exception_throwing_generator(1, -1), 1)

    # The generator raises SayWhenError part-way through (at item 3).
    with self.assertRaises(SayWhenError):
        self.pool.map(sqr, exception_throwing_generator(10, 3), 1)

    # Same check with a hand-written failing iterator, again run twice.
    for _ in range(2):
        with self.assertRaises(SayWhenError):
            self.pool.map(sqr, SpecialIterable(), 1)
def test_async(self):
res = self.pool.apply_async(sqr, (7, TIMEOUT1,))
get = TimingWrapper(res.get)
@ -1863,6 +1891,13 @@ class _TestPool(BaseTestCase):
if self.TYPE == 'manager':
self.skipTest('test not appropriate for {}'.format(self.TYPE))
# SayWhenError seen at the very first of the iterable
it = self.pool.imap(sqr, exception_throwing_generator(1, -1), 1)
self.assertRaises(SayWhenError, it.__next__)
# again, make sure it's reentrant
it = self.pool.imap(sqr, exception_throwing_generator(1, -1), 1)
self.assertRaises(SayWhenError, it.__next__)
it = self.pool.imap(sqr, exception_throwing_generator(10, 3), 1)
for i in range(3):
self.assertEqual(next(it), i*i)
@ -1889,6 +1924,17 @@ class _TestPool(BaseTestCase):
if self.TYPE == 'manager':
self.skipTest('test not appropriate for {}'.format(self.TYPE))
# SayWhenError seen at the very first of the iterable
it = self.pool.imap_unordered(sqr,
exception_throwing_generator(1, -1),
1)
self.assertRaises(SayWhenError, it.__next__)
# again, make sure it's reentrant
it = self.pool.imap_unordered(sqr,
exception_throwing_generator(1, -1),
1)
self.assertRaises(SayWhenError, it.__next__)
it = self.pool.imap_unordered(sqr,
exception_throwing_generator(10, 3),
1)
@ -1970,7 +2016,7 @@ class _TestPool(BaseTestCase):
except Exception as e:
exc = e
else:
raise AssertionError('expected RuntimeError')
self.fail('expected RuntimeError')
self.assertIs(type(exc), RuntimeError)
self.assertEqual(exc.args, (123,))
cause = exc.__cause__
@ -1984,6 +2030,17 @@ class _TestPool(BaseTestCase):
sys.excepthook(*sys.exc_info())
self.assertIn('raise RuntimeError(123) # some comment',
f1.getvalue())
# _helper_reraises_exception should not make the error
# a remote exception
with self.Pool(1) as p:
try:
p.map(sqr, exception_throwing_generator(1, -1), 1)
except Exception as e:
exc = e
else:
self.fail('expected SayWhenError')
self.assertIs(type(exc), SayWhenError)
self.assertIs(exc.__cause__, None)
@classmethod
def _test_wrapped_exception(cls):

View File

@ -291,6 +291,10 @@ Extension Modules
Library
-------
- bpo-28699: Fixed a bug in pools in multiprocessing.pool where an exception
raised at the very start of an iterable could be swallowed or could make the
program hang. Patch by Davin Potts and Xiang Zhang.
- bpo-23890: unittest.TestCase.assertRaises() now manually breaks a reference
cycle to not keep objects alive longer than expected.