Issue #21565: multiprocessing: use contex-manager protocol for synchronization

primitives.
This commit is contained in:
Charles-François Natali 2014-05-25 14:12:12 +01:00
parent 1691e35953
commit a924fc7abc
9 changed files with 43 additions and 102 deletions

View File

@ -1320,6 +1320,9 @@ processes.
Note that accessing the ctypes object through the wrapper can be a lot slower
than accessing the raw ctypes object.
.. versionchanged:: 3.5
Synchronized objects support the :term:`context manager` protocol.
The table below compares the syntax for creating shared ctypes objects from
shared memory with the normal ctypes syntax. (In the table ``MyStruct`` is some

View File

@ -59,9 +59,8 @@ class Connection(object):
return True
if timeout <= 0.0:
return False
self._in.not_empty.acquire()
self._in.not_empty.wait(timeout)
self._in.not_empty.release()
with self._in.not_empty:
self._in.not_empty.wait(timeout)
return self._in.qsize() > 0
def close(self):

View File

@ -216,9 +216,8 @@ class Heap(object):
assert 0 <= size < sys.maxsize
if os.getpid() != self._lastpid:
self.__init__() # reinitialize after fork
self._lock.acquire()
self._free_pending_blocks()
try:
with self._lock:
self._free_pending_blocks()
size = self._roundup(max(size,1), self._alignment)
(arena, start, stop) = self._malloc(size)
new_stop = start + size
@ -227,8 +226,6 @@ class Heap(object):
block = (arena, start, new_stop)
self._allocated_blocks.add(block)
return block
finally:
self._lock.release()
#
# Class representing a chunk of an mmap -- can be inherited by child process

View File

@ -306,8 +306,7 @@ class Server(object):
'''
Return some info --- useful to spot problems with refcounting
'''
self.mutex.acquire()
try:
with self.mutex:
result = []
keys = list(self.id_to_obj.keys())
keys.sort()
@ -317,8 +316,6 @@ class Server(object):
(ident, self.id_to_refcount[ident],
str(self.id_to_obj[ident][0])[:75]))
return '\n'.join(result)
finally:
self.mutex.release()
def number_of_objects(self, c):
'''
@ -343,8 +340,7 @@ class Server(object):
'''
Create a new shared object and return its id
'''
self.mutex.acquire()
try:
with self.mutex:
callable, exposed, method_to_typeid, proxytype = \
self.registry[typeid]
@ -374,8 +370,6 @@ class Server(object):
# has been created.
self.incref(c, ident)
return ident, tuple(exposed)
finally:
self.mutex.release()
def get_methods(self, c, token):
'''
@ -392,22 +386,16 @@ class Server(object):
self.serve_client(c)
def incref(self, c, ident):
self.mutex.acquire()
try:
with self.mutex:
self.id_to_refcount[ident] += 1
finally:
self.mutex.release()
def decref(self, c, ident):
self.mutex.acquire()
try:
with self.mutex:
assert self.id_to_refcount[ident] >= 1
self.id_to_refcount[ident] -= 1
if self.id_to_refcount[ident] == 0:
del self.id_to_obj[ident], self.id_to_refcount[ident]
util.debug('disposing of obj with id %r', ident)
finally:
self.mutex.release()
#
# Class to represent state of a manager
@ -671,14 +659,11 @@ class BaseProxy(object):
def __init__(self, token, serializer, manager=None,
authkey=None, exposed=None, incref=True):
BaseProxy._mutex.acquire()
try:
with BaseProxy._mutex:
tls_idset = BaseProxy._address_to_local.get(token.address, None)
if tls_idset is None:
tls_idset = util.ForkAwareLocal(), ProcessLocalSet()
BaseProxy._address_to_local[token.address] = tls_idset
finally:
BaseProxy._mutex.release()
# self._tls is used to record the connection used by this
# thread to communicate with the manager at token.address

View File

@ -666,8 +666,7 @@ class IMapIterator(object):
return self
def next(self, timeout=None):
self._cond.acquire()
try:
with self._cond:
try:
item = self._items.popleft()
except IndexError:
@ -680,8 +679,6 @@ class IMapIterator(object):
if self._index == self._length:
raise StopIteration
raise TimeoutError
finally:
self._cond.release()
success, value = item
if success:
@ -691,8 +688,7 @@ class IMapIterator(object):
__next__ = next # XXX
def _set(self, i, obj):
self._cond.acquire()
try:
with self._cond:
if self._index == i:
self._items.append(obj)
self._index += 1
@ -706,18 +702,13 @@ class IMapIterator(object):
if self._index == self._length:
del self._cache[self._job]
finally:
self._cond.release()
def _set_length(self, length):
self._cond.acquire()
try:
with self._cond:
self._length = length
if self._index == self._length:
self._cond.notify()
del self._cache[self._job]
finally:
self._cond.release()
#
# Class whose instances are returned by `Pool.imap_unordered()`
@ -726,15 +717,12 @@ class IMapIterator(object):
class IMapUnorderedIterator(IMapIterator):
def _set(self, i, obj):
self._cond.acquire()
try:
with self._cond:
self._items.append(obj)
self._index += 1
self._cond.notify()
if self._index == self._length:
del self._cache[self._job]
finally:
self._cond.release()
#
#
@ -760,10 +748,7 @@ class ThreadPool(Pool):
@staticmethod
def _help_stuff_finish(inqueue, task_handler, size):
# put sentinels at head of inqueue to make workers finish
inqueue.not_empty.acquire()
try:
with inqueue.not_empty:
inqueue.queue.clear()
inqueue.queue.extend([None] * size)
inqueue.not_empty.notify_all()
finally:
inqueue.not_empty.release()

View File

@ -81,14 +81,11 @@ class Queue(object):
if not self._sem.acquire(block, timeout):
raise Full
self._notempty.acquire()
try:
with self._notempty:
if self._thread is None:
self._start_thread()
self._buffer.append(obj)
self._notempty.notify()
finally:
self._notempty.release()
def get(self, block=True, timeout=None):
if block and timeout is None:
@ -201,12 +198,9 @@ class Queue(object):
@staticmethod
def _finalize_close(buffer, notempty):
debug('telling queue thread to quit')
notempty.acquire()
try:
with notempty:
buffer.append(_sentinel)
notempty.notify()
finally:
notempty.release()
@staticmethod
def _feed(buffer, notempty, send_bytes, writelock, close, ignore_epipe):
@ -295,35 +289,24 @@ class JoinableQueue(Queue):
if not self._sem.acquire(block, timeout):
raise Full
self._notempty.acquire()
self._cond.acquire()
try:
with self._notempty, self._cond:
if self._thread is None:
self._start_thread()
self._buffer.append(obj)
self._unfinished_tasks.release()
self._notempty.notify()
finally:
self._cond.release()
self._notempty.release()
def task_done(self):
self._cond.acquire()
try:
with self._cond:
if not self._unfinished_tasks.acquire(False):
raise ValueError('task_done() called too many times')
if self._unfinished_tasks._semlock._is_zero():
self._cond.notify_all()
finally:
self._cond.release()
def join(self):
self._cond.acquire()
try:
with self._cond:
if not self._unfinished_tasks._semlock._is_zero():
self._cond.wait()
finally:
self._cond.release()
#
# Simplified Queue type -- really just a locked pipe

View File

@ -188,6 +188,12 @@ class SynchronizedBase(object):
self.acquire = self._lock.acquire
self.release = self._lock.release
def __enter__(self):
return self._lock.__enter__()
def __exit__(self, *args):
return self._lock.__exit__(*args)
def __reduce__(self):
assert_spawning(self)
return synchronized, (self._obj, self._lock)
@ -212,32 +218,20 @@ class SynchronizedArray(SynchronizedBase):
return len(self._obj)
def __getitem__(self, i):
self.acquire()
try:
with self:
return self._obj[i]
finally:
self.release()
def __setitem__(self, i, value):
self.acquire()
try:
with self:
self._obj[i] = value
finally:
self.release()
def __getslice__(self, start, stop):
self.acquire()
try:
with self:
return self._obj[start:stop]
finally:
self.release()
def __setslice__(self, start, stop, values):
self.acquire()
try:
with self:
self._obj[start:stop] = values
finally:
self.release()
class SynchronizedString(SynchronizedArray):

View File

@ -337,34 +337,24 @@ class Event(object):
self._flag = ctx.Semaphore(0)
def is_set(self):
self._cond.acquire()
try:
with self._cond:
if self._flag.acquire(False):
self._flag.release()
return True
return False
finally:
self._cond.release()
def set(self):
self._cond.acquire()
try:
with self._cond:
self._flag.acquire(False)
self._flag.release()
self._cond.notify_all()
finally:
self._cond.release()
def clear(self):
self._cond.acquire()
try:
with self._cond:
self._flag.acquire(False)
finally:
self._cond.release()
def wait(self, timeout=None):
self._cond.acquire()
try:
with self._cond:
if self._flag.acquire(False):
self._flag.release()
else:
@ -374,8 +364,6 @@ class Event(object):
self._flag.release()
return True
return False
finally:
self._cond.release()
#
# Barrier

View File

@ -327,6 +327,13 @@ class ForkAwareThreadLock(object):
self.acquire = self._lock.acquire
self.release = self._lock.release
def __enter__(self):
return self._lock.__enter__()
def __exit__(self, *args):
return self._lock.__exit__(*args)
class ForkAwareLocal(threading.local):
def __init__(self):
register_after_fork(self, lambda obj : obj.__dict__.clear())