raise a ValueError instead of an AssertionError when pool is an invalid state

This commit is contained in:
Benjamin Peterson 2012-09-25 12:45:42 -04:00
parent 3331a20464
commit 3095f4724e
2 changed files with 10 additions and 8 deletions

View File

@ -225,7 +225,6 @@ class Pool(object):
Apply `func` to each element in `iterable`, collecting the results Apply `func` to each element in `iterable`, collecting the results
in a list that is returned. in a list that is returned.
''' '''
assert self._state == RUN
return self._map_async(func, iterable, mapstar, chunksize).get() return self._map_async(func, iterable, mapstar, chunksize).get()
def starmap(self, func, iterable, chunksize=None): def starmap(self, func, iterable, chunksize=None):
@ -234,7 +233,6 @@ class Pool(object):
be iterables as well and will be unpacked as arguments. Hence be iterables as well and will be unpacked as arguments. Hence
`func` and (a, b) becomes func(a, b). `func` and (a, b) becomes func(a, b).
''' '''
assert self._state == RUN
return self._map_async(func, iterable, starmapstar, chunksize).get() return self._map_async(func, iterable, starmapstar, chunksize).get()
def starmap_async(self, func, iterable, chunksize=None, callback=None, def starmap_async(self, func, iterable, chunksize=None, callback=None,
@ -242,7 +240,6 @@ class Pool(object):
''' '''
Asynchronous version of `starmap()` method. Asynchronous version of `starmap()` method.
''' '''
assert self._state == RUN
return self._map_async(func, iterable, starmapstar, chunksize, return self._map_async(func, iterable, starmapstar, chunksize,
callback, error_callback) callback, error_callback)
@ -250,7 +247,8 @@ class Pool(object):
''' '''
Equivalent of `map()` -- can be MUCH slower than `Pool.map()`. Equivalent of `map()` -- can be MUCH slower than `Pool.map()`.
''' '''
assert self._state == RUN if self._state != RUN:
raise ValueError("Pool not running")
if chunksize == 1: if chunksize == 1:
result = IMapIterator(self._cache) result = IMapIterator(self._cache)
self._taskqueue.put((((result._job, i, func, (x,), {}) self._taskqueue.put((((result._job, i, func, (x,), {})
@ -268,7 +266,8 @@ class Pool(object):
''' '''
Like `imap()` method but ordering of results is arbitrary. Like `imap()` method but ordering of results is arbitrary.
''' '''
assert self._state == RUN if self._state != RUN:
raise ValueError("Pool not running")
if chunksize == 1: if chunksize == 1:
result = IMapUnorderedIterator(self._cache) result = IMapUnorderedIterator(self._cache)
self._taskqueue.put((((result._job, i, func, (x,), {}) self._taskqueue.put((((result._job, i, func, (x,), {})
@ -287,7 +286,8 @@ class Pool(object):
''' '''
Asynchronous version of `apply()` method. Asynchronous version of `apply()` method.
''' '''
assert self._state == RUN if self._state != RUN:
raise ValueError("Pool not running")
result = ApplyResult(self._cache, callback, error_callback) result = ApplyResult(self._cache, callback, error_callback)
self._taskqueue.put(([(result._job, None, func, args, kwds)], None)) self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
return result return result
@ -297,7 +297,6 @@ class Pool(object):
''' '''
Asynchronous version of `map()` method. Asynchronous version of `map()` method.
''' '''
assert self._state == RUN
return self._map_async(func, iterable, mapstar, chunksize) return self._map_async(func, iterable, mapstar, chunksize)
def _map_async(self, func, iterable, mapper, chunksize=None, callback=None, def _map_async(self, func, iterable, mapper, chunksize=None, callback=None,
@ -305,6 +304,8 @@ class Pool(object):
''' '''
Helper function to implement map, starmap and their async counterparts. Helper function to implement map, starmap and their async counterparts.
''' '''
if self._state != RUN:
raise ValueError("Pool not running")
if not hasattr(iterable, '__len__'): if not hasattr(iterable, '__len__'):
iterable = list(iterable) iterable = list(iterable)

View File

@ -1727,7 +1727,8 @@ class _TestPool(BaseTestCase):
with multiprocessing.Pool(2) as p: with multiprocessing.Pool(2) as p:
r = p.map_async(sqr, L) r = p.map_async(sqr, L)
self.assertEqual(r.get(), expected) self.assertEqual(r.get(), expected)
self.assertRaises(AssertionError, p.map_async, sqr, L) print(p._state)
self.assertRaises(ValueError, p.map_async, sqr, L)
def raising(): def raising():
raise KeyError("key") raise KeyError("key")