asyncio: Locks refactor: use a separate context manager; remove Semaphore._locked.
This commit is contained in:
parent
ab27a9fc4b
commit
ab3c88983b
|
@ -9,6 +9,36 @@ from . import futures
|
|||
from . import tasks
|
||||
|
||||
|
||||
class _ContextManager:
|
||||
"""Context manager.
|
||||
|
||||
This enables the following idiom for acquiring and releasing a
|
||||
lock around a block:
|
||||
|
||||
with (yield from lock):
|
||||
<block>
|
||||
|
||||
while failing loudly when accidentally using:
|
||||
|
||||
with lock:
|
||||
<block>
|
||||
"""
|
||||
|
||||
def __init__(self, lock):
|
||||
self._lock = lock
|
||||
|
||||
def __enter__(self):
|
||||
# We have no use for the "as ..." clause in the with
|
||||
# statement for locks.
|
||||
return None
|
||||
|
||||
def __exit__(self, *args):
|
||||
try:
|
||||
self._lock.release()
|
||||
finally:
|
||||
self._lock = None # Crudely prevent reuse.
|
||||
|
||||
|
||||
class Lock:
|
||||
"""Primitive lock objects.
|
||||
|
||||
|
@ -124,17 +154,29 @@ class Lock:
|
|||
raise RuntimeError('Lock is not acquired.')
|
||||
|
||||
def __enter__(self):
|
||||
if not self._locked:
|
||||
raise RuntimeError(
|
||||
'"yield from" should be used as context manager expression')
|
||||
return True
|
||||
raise RuntimeError(
|
||||
'"yield from" should be used as context manager expression')
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.release()
|
||||
# This must exist because __enter__ exists, even though that
|
||||
# always raises; that's how the with-statement works.
|
||||
pass
|
||||
|
||||
def __iter__(self):
|
||||
# This is not a coroutine. It is meant to enable the idiom:
|
||||
#
|
||||
# with (yield from lock):
|
||||
# <block>
|
||||
#
|
||||
# as an alternative to:
|
||||
#
|
||||
# yield from lock.acquire()
|
||||
# try:
|
||||
# <block>
|
||||
# finally:
|
||||
# lock.release()
|
||||
yield from self.acquire()
|
||||
return self
|
||||
return _ContextManager(self)
|
||||
|
||||
|
||||
class Event:
|
||||
|
@ -311,14 +353,16 @@ class Condition:
|
|||
self.notify(len(self._waiters))
|
||||
|
||||
def __enter__(self):
|
||||
return self._lock.__enter__()
|
||||
raise RuntimeError(
|
||||
'"yield from" should be used as context manager expression')
|
||||
|
||||
def __exit__(self, *args):
|
||||
return self._lock.__exit__(*args)
|
||||
pass
|
||||
|
||||
def __iter__(self):
|
||||
# See comment in Lock.__iter__().
|
||||
yield from self.acquire()
|
||||
return self
|
||||
return _ContextManager(self)
|
||||
|
||||
|
||||
class Semaphore:
|
||||
|
@ -341,7 +385,6 @@ class Semaphore:
|
|||
raise ValueError("Semaphore initial value must be >= 0")
|
||||
self._value = value
|
||||
self._waiters = collections.deque()
|
||||
self._locked = (value == 0)
|
||||
if loop is not None:
|
||||
self._loop = loop
|
||||
else:
|
||||
|
@ -349,7 +392,7 @@ class Semaphore:
|
|||
|
||||
def __repr__(self):
|
||||
res = super().__repr__()
|
||||
extra = 'locked' if self._locked else 'unlocked,value:{}'.format(
|
||||
extra = 'locked' if self.locked() else 'unlocked,value:{}'.format(
|
||||
self._value)
|
||||
if self._waiters:
|
||||
extra = '{},waiters:{}'.format(extra, len(self._waiters))
|
||||
|
@ -357,7 +400,7 @@ class Semaphore:
|
|||
|
||||
def locked(self):
|
||||
"""Returns True if semaphore can not be acquired immediately."""
|
||||
return self._locked
|
||||
return self._value == 0
|
||||
|
||||
@tasks.coroutine
|
||||
def acquire(self):
|
||||
|
@ -371,8 +414,6 @@ class Semaphore:
|
|||
"""
|
||||
if not self._waiters and self._value > 0:
|
||||
self._value -= 1
|
||||
if self._value == 0:
|
||||
self._locked = True
|
||||
return True
|
||||
|
||||
fut = futures.Future(loop=self._loop)
|
||||
|
@ -380,8 +421,6 @@ class Semaphore:
|
|||
try:
|
||||
yield from fut
|
||||
self._value -= 1
|
||||
if self._value == 0:
|
||||
self._locked = True
|
||||
return True
|
||||
finally:
|
||||
self._waiters.remove(fut)
|
||||
|
@ -392,23 +431,22 @@ class Semaphore:
|
|||
become larger than zero again, wake up that coroutine.
|
||||
"""
|
||||
self._value += 1
|
||||
self._locked = False
|
||||
for waiter in self._waiters:
|
||||
if not waiter.done():
|
||||
waiter.set_result(True)
|
||||
break
|
||||
|
||||
def __enter__(self):
|
||||
# TODO: This is questionable. How do we know the user actually
|
||||
# wrote "with (yield from sema)" instead of "with sema"?
|
||||
return True
|
||||
raise RuntimeError(
|
||||
'"yield from" should be used as context manager expression')
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.release()
|
||||
pass
|
||||
|
||||
def __iter__(self):
|
||||
# See comment in Lock.__iter__().
|
||||
yield from self.acquire()
|
||||
return self
|
||||
return _ContextManager(self)
|
||||
|
||||
|
||||
class BoundedSemaphore(Semaphore):
|
||||
|
|
|
@ -208,6 +208,24 @@ class LockTests(unittest.TestCase):
|
|||
|
||||
self.assertFalse(lock.locked())
|
||||
|
||||
def test_context_manager_cant_reuse(self):
|
||||
lock = asyncio.Lock(loop=self.loop)
|
||||
|
||||
@asyncio.coroutine
|
||||
def acquire_lock():
|
||||
return (yield from lock)
|
||||
|
||||
# This spells "yield from lock" outside a generator.
|
||||
cm = self.loop.run_until_complete(acquire_lock())
|
||||
with cm:
|
||||
self.assertTrue(lock.locked())
|
||||
|
||||
self.assertFalse(lock.locked())
|
||||
|
||||
with self.assertRaises(AttributeError):
|
||||
with cm:
|
||||
pass
|
||||
|
||||
def test_context_manager_no_yield(self):
|
||||
lock = asyncio.Lock(loop=self.loop)
|
||||
|
||||
|
@ -219,6 +237,8 @@ class LockTests(unittest.TestCase):
|
|||
str(err),
|
||||
'"yield from" should be used as context manager expression')
|
||||
|
||||
self.assertFalse(lock.locked())
|
||||
|
||||
|
||||
class EventTests(unittest.TestCase):
|
||||
|
||||
|
@ -655,6 +675,8 @@ class ConditionTests(unittest.TestCase):
|
|||
str(err),
|
||||
'"yield from" should be used as context manager expression')
|
||||
|
||||
self.assertFalse(cond.locked())
|
||||
|
||||
|
||||
class SemaphoreTests(unittest.TestCase):
|
||||
|
||||
|
@ -830,6 +852,19 @@ class SemaphoreTests(unittest.TestCase):
|
|||
|
||||
self.assertEqual(2, sem._value)
|
||||
|
||||
def test_context_manager_no_yield(self):
|
||||
sem = asyncio.Semaphore(2, loop=self.loop)
|
||||
|
||||
try:
|
||||
with sem:
|
||||
self.fail('RuntimeError is not raised in with expression')
|
||||
except RuntimeError as err:
|
||||
self.assertEqual(
|
||||
str(err),
|
||||
'"yield from" should be used as context manager expression')
|
||||
|
||||
self.assertEqual(2, sem._value)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue