asyncio: Locks refactor: use a separate context manager; remove Semaphore._locked.

This commit is contained in:
Guido van Rossum 2014-01-25 16:51:57 -08:00
parent ab27a9fc4b
commit ab3c88983b
2 changed files with 95 additions and 22 deletions

View File

@ -9,6 +9,36 @@ from . import futures
from . import tasks
class _ContextManager:
"""Context manager.
This enables the following idiom for acquiring and releasing a
lock around a block:
with (yield from lock):
<block>
while failing loudly when accidentally using:
with lock:
<block>
"""
def __init__(self, lock):
self._lock = lock
def __enter__(self):
# We have no use for the "as ..." clause in the with
# statement for locks.
return None
def __exit__(self, *args):
try:
self._lock.release()
finally:
self._lock = None # Crudely prevent reuse.
class Lock:
"""Primitive lock objects.
@ -124,17 +154,29 @@ class Lock:
raise RuntimeError('Lock is not acquired.')
def __enter__(self):
if not self._locked:
raise RuntimeError(
'"yield from" should be used as context manager expression')
return True
def __exit__(self, *args):
self.release()
# This must exist because __enter__ exists, even though that
# always raises; that's how the with-statement works.
pass
def __iter__(self):
# This is not a coroutine. It is meant to enable the idiom:
#
# with (yield from lock):
# <block>
#
# as an alternative to:
#
# yield from lock.acquire()
# try:
# <block>
# finally:
# lock.release()
yield from self.acquire()
return self
return _ContextManager(self)
class Event:
@ -311,14 +353,16 @@ class Condition:
self.notify(len(self._waiters))
def __enter__(self):
return self._lock.__enter__()
raise RuntimeError(
'"yield from" should be used as context manager expression')
def __exit__(self, *args):
return self._lock.__exit__(*args)
pass
def __iter__(self):
# See comment in Lock.__iter__().
yield from self.acquire()
return self
return _ContextManager(self)
class Semaphore:
@ -341,7 +385,6 @@ class Semaphore:
raise ValueError("Semaphore initial value must be >= 0")
self._value = value
self._waiters = collections.deque()
self._locked = (value == 0)
if loop is not None:
self._loop = loop
else:
@ -349,7 +392,7 @@ class Semaphore:
def __repr__(self):
res = super().__repr__()
extra = 'locked' if self._locked else 'unlocked,value:{}'.format(
extra = 'locked' if self.locked() else 'unlocked,value:{}'.format(
self._value)
if self._waiters:
extra = '{},waiters:{}'.format(extra, len(self._waiters))
@ -357,7 +400,7 @@ class Semaphore:
def locked(self):
"""Returns True if semaphore can not be acquired immediately."""
return self._locked
return self._value == 0
@tasks.coroutine
def acquire(self):
@ -371,8 +414,6 @@ class Semaphore:
"""
if not self._waiters and self._value > 0:
self._value -= 1
if self._value == 0:
self._locked = True
return True
fut = futures.Future(loop=self._loop)
@ -380,8 +421,6 @@ class Semaphore:
try:
yield from fut
self._value -= 1
if self._value == 0:
self._locked = True
return True
finally:
self._waiters.remove(fut)
@ -392,23 +431,22 @@ class Semaphore:
become larger than zero again, wake up that coroutine.
"""
self._value += 1
self._locked = False
for waiter in self._waiters:
if not waiter.done():
waiter.set_result(True)
break
def __enter__(self):
# TODO: This is questionable. How do we know the user actually
# wrote "with (yield from sema)" instead of "with sema"?
return True
raise RuntimeError(
'"yield from" should be used as context manager expression')
def __exit__(self, *args):
self.release()
pass
def __iter__(self):
# See comment in Lock.__iter__().
yield from self.acquire()
return self
return _ContextManager(self)
class BoundedSemaphore(Semaphore):

View File

@ -208,6 +208,24 @@ class LockTests(unittest.TestCase):
self.assertFalse(lock.locked())
def test_context_manager_cant_reuse(self):
lock = asyncio.Lock(loop=self.loop)
@asyncio.coroutine
def acquire_lock():
return (yield from lock)
# This spells "yield from lock" outside a generator.
cm = self.loop.run_until_complete(acquire_lock())
with cm:
self.assertTrue(lock.locked())
self.assertFalse(lock.locked())
with self.assertRaises(AttributeError):
with cm:
pass
def test_context_manager_no_yield(self):
lock = asyncio.Lock(loop=self.loop)
@ -219,6 +237,8 @@ class LockTests(unittest.TestCase):
str(err),
'"yield from" should be used as context manager expression')
self.assertFalse(lock.locked())
class EventTests(unittest.TestCase):
@ -655,6 +675,8 @@ class ConditionTests(unittest.TestCase):
str(err),
'"yield from" should be used as context manager expression')
self.assertFalse(cond.locked())
class SemaphoreTests(unittest.TestCase):
@ -830,6 +852,19 @@ class SemaphoreTests(unittest.TestCase):
self.assertEqual(2, sem._value)
def test_context_manager_no_yield(self):
sem = asyncio.Semaphore(2, loop=self.loop)
try:
with sem:
self.fail('RuntimeError is not raised in with expression')
except RuntimeError as err:
self.assertEqual(
str(err),
'"yield from" should be used as context manager expression')
self.assertEqual(2, sem._value)
if __name__ == '__main__':
unittest.main()