bpo-35378: Fix multiprocessing.Pool references (GH-11627)
Changes in this commit:

1. Use a _strong_ reference between the Pool and its associated iterators.
2. Rework PR #8450 to eliminate a reference cycle in the Pool.

There is no test in this commit because any test that exercises this behaviour automatically would need to delete the pool before joining it, to check that the pool object is garbage collected and does not hang. Doing that can leak threads and processes (see https://bugs.python.org/issue35413).
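As a rough illustration (not part of the commit), the usage pattern this change is meant to make safe is a pool whose only remaining reference lives in one of its result iterators; the `square` helper and pool size below are arbitrary:

    import multiprocessing

    def square(x):
        return x * x

    if __name__ == '__main__':
        # The Pool object is never bound to a name, so the imap iterator is
        # the only thing keeping it reachable. With the strong reference from
        # iterator to pool added here, the pool stays alive until iteration
        # finishes instead of being collected mid-iteration and hanging.
        results = multiprocessing.Pool(2).imap(square, range(10))
        print(list(results))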
parent 4b250fc1da
commit 3766f18c52
Lib/multiprocessing/pool.py

@@ -151,8 +151,9 @@ class Pool(object):
     '''
     _wrap_exception = True
 
-    def Process(self, *args, **kwds):
-        return self._ctx.Process(*args, **kwds)
+    @staticmethod
+    def Process(ctx, *args, **kwds):
+        return ctx.Process(*args, **kwds)
 
     def __init__(self, processes=None, initializer=None, initargs=(),
                  maxtasksperchild=None, context=None):
@@ -190,7 +191,10 @@ class Pool(object):
 
         self._worker_handler = threading.Thread(
             target=Pool._handle_workers,
-            args=(self, )
+            args=(self._cache, self._taskqueue, self._ctx, self.Process,
+                  self._processes, self._pool, self._inqueue, self._outqueue,
+                  self._initializer, self._initargs, self._maxtasksperchild,
+                  self._wrap_exception)
             )
         self._worker_handler.daemon = True
         self._worker_handler._state = RUN
@@ -236,43 +240,61 @@ class Pool(object):
                 f'state={self._state} '
                 f'pool_size={len(self._pool)}>')
 
-    def _join_exited_workers(self):
+    @staticmethod
+    def _join_exited_workers(pool):
         """Cleanup after any worker processes which have exited due to reaching
         their specified lifetime. Returns True if any workers were cleaned up.
         """
         cleaned = False
-        for i in reversed(range(len(self._pool))):
-            worker = self._pool[i]
+        for i in reversed(range(len(pool))):
+            worker = pool[i]
             if worker.exitcode is not None:
                 # worker exited
                 util.debug('cleaning up worker %d' % i)
                 worker.join()
                 cleaned = True
-                del self._pool[i]
+                del pool[i]
         return cleaned
 
     def _repopulate_pool(self):
+        return self._repopulate_pool_static(self._ctx, self.Process,
+                                            self._processes,
+                                            self._pool, self._inqueue,
+                                            self._outqueue, self._initializer,
+                                            self._initargs,
+                                            self._maxtasksperchild,
+                                            self._wrap_exception)
+
+    @staticmethod
+    def _repopulate_pool_static(ctx, Process, processes, pool, inqueue,
+                                outqueue, initializer, initargs,
+                                maxtasksperchild, wrap_exception):
         """Bring the number of pool processes up to the specified number,
         for use after reaping workers which have exited.
         """
-        for i in range(self._processes - len(self._pool)):
-            w = self.Process(target=worker,
-                             args=(self._inqueue, self._outqueue,
-                                   self._initializer,
-                                   self._initargs, self._maxtasksperchild,
-                                   self._wrap_exception)
-                             )
+        for i in range(processes - len(pool)):
+            w = Process(ctx, target=worker,
+                        args=(inqueue, outqueue,
+                              initializer,
+                              initargs, maxtasksperchild,
+                              wrap_exception))
             w.name = w.name.replace('Process', 'PoolWorker')
             w.daemon = True
             w.start()
-            self._pool.append(w)
+            pool.append(w)
             util.debug('added worker')
 
-    def _maintain_pool(self):
+    @staticmethod
+    def _maintain_pool(ctx, Process, processes, pool, inqueue, outqueue,
+                       initializer, initargs, maxtasksperchild,
+                       wrap_exception):
         """Clean up any exited workers and start replacements for them.
         """
-        if self._join_exited_workers():
-            self._repopulate_pool()
+        if Pool._join_exited_workers(pool):
+            Pool._repopulate_pool_static(ctx, Process, processes, pool,
+                                         inqueue, outqueue, initializer,
+                                         initargs, maxtasksperchild,
+                                         wrap_exception)
 
     def _setup_queues(self):
         self._inqueue = self._ctx.SimpleQueue()
@@ -331,7 +353,7 @@ class Pool(object):
         '''
         self._check_running()
         if chunksize == 1:
-            result = IMapIterator(self._cache)
+            result = IMapIterator(self)
             self._taskqueue.put(
                 (
                     self._guarded_task_generation(result._job, func, iterable),
@@ -344,7 +366,7 @@ class Pool(object):
                     "Chunksize must be 1+, not {0:n}".format(
                         chunksize))
             task_batches = Pool._get_tasks(func, iterable, chunksize)
-            result = IMapIterator(self._cache)
+            result = IMapIterator(self)
             self._taskqueue.put(
                 (
                     self._guarded_task_generation(result._job,
@@ -360,7 +382,7 @@ class Pool(object):
         '''
         self._check_running()
         if chunksize == 1:
-            result = IMapUnorderedIterator(self._cache)
+            result = IMapUnorderedIterator(self)
             self._taskqueue.put(
                 (
                     self._guarded_task_generation(result._job, func, iterable),
@@ -372,7 +394,7 @@ class Pool(object):
                 raise ValueError(
                     "Chunksize must be 1+, not {0!r}".format(chunksize))
             task_batches = Pool._get_tasks(func, iterable, chunksize)
-            result = IMapUnorderedIterator(self._cache)
+            result = IMapUnorderedIterator(self)
             self._taskqueue.put(
                 (
                     self._guarded_task_generation(result._job,
@@ -388,7 +410,7 @@ class Pool(object):
         Asynchronous version of `apply()` method.
         '''
         self._check_running()
-        result = ApplyResult(self._cache, callback, error_callback)
+        result = ApplyResult(self, callback, error_callback)
         self._taskqueue.put(([(result._job, 0, func, args, kwds)], None))
         return result
 
@@ -417,7 +439,7 @@ class Pool(object):
             chunksize = 0
 
         task_batches = Pool._get_tasks(func, iterable, chunksize)
-        result = MapResult(self._cache, chunksize, len(iterable), callback,
+        result = MapResult(self, chunksize, len(iterable), callback,
                            error_callback=error_callback)
         self._taskqueue.put(
             (
@@ -430,16 +452,20 @@ class Pool(object):
         return result
 
     @staticmethod
-    def _handle_workers(pool):
+    def _handle_workers(cache, taskqueue, ctx, Process, processes, pool,
+                        inqueue, outqueue, initializer, initargs,
+                        maxtasksperchild, wrap_exception):
         thread = threading.current_thread()
 
         # Keep maintaining workers until the cache gets drained, unless the pool
         # is terminated.
-        while thread._state == RUN or (pool._cache and thread._state != TERMINATE):
-            pool._maintain_pool()
+        while thread._state == RUN or (cache and thread._state != TERMINATE):
+            Pool._maintain_pool(ctx, Process, processes, pool, inqueue,
+                                outqueue, initializer, initargs,
+                                maxtasksperchild, wrap_exception)
             time.sleep(0.1)
         # send sentinel to stop workers
-        pool._taskqueue.put(None)
+        taskqueue.put(None)
         util.debug('worker handler exiting')
 
     @staticmethod
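The hunk above passes the pool's state to the worker-handler thread as explicit arguments instead of handing it `self`, so the thread no longer keeps the Pool object alive. A generic sketch of that pattern, with hypothetical names (`Owner`, `_loop`):

    import threading
    import time

    class Owner:
        def __init__(self):
            self.tasks = []
            # Give the thread plain state rather than `self`: the thread then
            # holds no reference to this instance, so dropping the last outside
            # reference to an Owner lets it be garbage collected even while the
            # daemon thread keeps running against the shared list.
            self._handler = threading.Thread(target=Owner._loop,
                                             args=(self.tasks,))
            self._handler.daemon = True
            self._handler.start()

        @staticmethod
        def _loop(tasks):
            while True:
                if tasks:
                    tasks.pop()()   # run one queued callable
                time.sleep(0.1)

The handler is a static method for the same reason it is in Pool: a bound method would pin the instance through its __self__ attribute.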
@@ -656,13 +682,14 @@ class Pool(object):
 
 class ApplyResult(object):
 
-    def __init__(self, cache, callback, error_callback):
+    def __init__(self, pool, callback, error_callback):
+        self._pool = pool
         self._event = threading.Event()
         self._job = next(job_counter)
-        self._cache = cache
+        self._cache = pool._cache
         self._callback = callback
         self._error_callback = error_callback
-        cache[self._job] = self
+        self._cache[self._job] = self
 
     def ready(self):
         return self._event.is_set()
@@ -692,6 +719,7 @@ class ApplyResult(object):
             self._error_callback(self._value)
         self._event.set()
         del self._cache[self._job]
+        self._pool = None
 
 AsyncResult = ApplyResult       # create alias -- see #17805
 
@@ -701,8 +729,8 @@ AsyncResult = ApplyResult       # create alias -- see #17805
 
 class MapResult(ApplyResult):
 
-    def __init__(self, cache, chunksize, length, callback, error_callback):
-        ApplyResult.__init__(self, cache, callback,
+    def __init__(self, pool, chunksize, length, callback, error_callback):
+        ApplyResult.__init__(self, pool, callback,
                              error_callback=error_callback)
         self._success = True
         self._value = [None] * length
@@ -710,7 +738,7 @@ class MapResult(ApplyResult):
         if chunksize <= 0:
             self._number_left = 0
             self._event.set()
-            del cache[self._job]
+            del self._cache[self._job]
         else:
             self._number_left = length//chunksize + bool(length % chunksize)
 
@@ -724,6 +752,7 @@ class MapResult(ApplyResult):
                     self._callback(self._value)
                 del self._cache[self._job]
                 self._event.set()
+                self._pool = None
         else:
             if not success and self._success:
                 # only store first exception
@@ -735,6 +764,7 @@ class MapResult(ApplyResult):
                     self._error_callback(self._value)
                 del self._cache[self._job]
                 self._event.set()
+                self._pool = None
 
 #
 # Class whose instances are returned by `Pool.imap()`
@@ -742,15 +772,16 @@ class MapResult(ApplyResult):
 
 class IMapIterator(object):
 
-    def __init__(self, cache):
+    def __init__(self, pool):
+        self._pool = pool
         self._cond = threading.Condition(threading.Lock())
         self._job = next(job_counter)
-        self._cache = cache
+        self._cache = pool._cache
         self._items = collections.deque()
         self._index = 0
         self._length = None
         self._unsorted = {}
-        cache[self._job] = self
+        self._cache[self._job] = self
 
     def __iter__(self):
         return self
@@ -761,12 +792,14 @@ class IMapIterator(object):
                 item = self._items.popleft()
             except IndexError:
                 if self._index == self._length:
+                    self._pool = None
                     raise StopIteration from None
                 self._cond.wait(timeout)
                 try:
                     item = self._items.popleft()
                 except IndexError:
                     if self._index == self._length:
+                        self._pool = None
                         raise StopIteration from None
                     raise TimeoutError from None
 
@@ -792,6 +825,7 @@ class IMapIterator(object):
 
             if self._index == self._length:
                 del self._cache[self._job]
+                self._pool = None
 
     def _set_length(self, length):
         with self._cond:
@@ -799,6 +833,7 @@ class IMapIterator(object):
             if self._index == self._length:
                 self._cond.notify()
                 del self._cache[self._job]
+                self._pool = None
 
 #
 # Class whose instances are returned by `Pool.imap_unordered()`
@@ -813,6 +848,7 @@ class IMapUnorderedIterator(IMapIterator):
             self._cond.notify()
             if self._index == self._length:
                 del self._cache[self._job]
+                self._pool = None
 
 #
 #
@@ -822,7 +858,7 @@ class ThreadPool(Pool):
     _wrap_exception = False
 
     @staticmethod
-    def Process(*args, **kwds):
+    def Process(ctx, *args, **kwds):
         from .dummy import Process
         return Process(*args, **kwds)
 
Lib/test/_test_multiprocessing.py

@@ -2593,7 +2593,6 @@ class _TestPool(BaseTestCase):
         pool = None
         support.gc_collect()
 
-
 def raising():
     raise KeyError("key")
 
Misc/NEWS.d/… (new file)

@@ -0,0 +1,6 @@
+Fix a reference issue inside :class:`multiprocessing.Pool` that caused
+the pool to remain alive if it was deleted without being closed or
+terminated explicitly. A new strong reference is added to the pool
+iterators to link the lifetime of the pool to the lifetime of its
+iterators so the pool does not get destroyed if a pool iterator is
+still alive.
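A minimal sketch of the lifetime link the entry describes, assuming the fixed behaviour (the `ident` helper and pool size are arbitrary; this is not the test discussed in the commit message):

    import gc
    import weakref
    from multiprocessing import Pool

    def ident(x):
        return x

    if __name__ == '__main__':
        pool = Pool(2)
        pool_ref = weakref.ref(pool)
        it = pool.imap(ident, range(4))
        del pool
        gc.collect()
        # The iterator's strong reference keeps the pool alive after the
        # user's own reference is gone...
        assert pool_ref() is not None
        print(list(it))   # ...so iteration completes instead of hanging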