asyncio: Locks improvements by Arnaud Faure: better repr(), change Conditio\

n structure.
This commit is contained in:
Guido van Rossum 2013-11-04 13:18:19 -08:00
parent b58d4a3209
commit ccea08462b
2 changed files with 124 additions and 25 deletions

View File

@ -155,9 +155,11 @@ class Event:
self._loop = events.get_event_loop() self._loop = events.get_event_loop()
def __repr__(self): def __repr__(self):
# TODO: add waiters:N if > 0.
res = super().__repr__() res = super().__repr__()
return '<{} [{}]>'.format(res[1:-1], 'set' if self._value else 'unset') extra = 'set' if self._value else 'unset'
if self._waiters:
extra = '{},waiters:{}'.format(extra, len(self._waiters))
return '<{} [{}]>'.format(res[1:-1], extra)
def is_set(self): def is_set(self):
"""Return true if and only if the internal flag is true.""" """Return true if and only if the internal flag is true."""
@ -201,20 +203,38 @@ class Event:
self._waiters.remove(fut) self._waiters.remove(fut)
# TODO: Why is this a Lock subclass? threading.Condition *has* a lock. class Condition:
class Condition(Lock): """A Condition implementation, our equivalent to threading.Condition.
"""A Condition implementation.
This class implements condition variable objects. A condition variable This class implements condition variable objects. A condition variable
allows one or more coroutines to wait until they are notified by another allows one or more coroutines to wait until they are notified by another
coroutine. coroutine.
A new Lock object is created and used as the underlying lock.
""" """
def __init__(self, *, loop=None): def __init__(self, *, loop=None):
super().__init__(loop=loop) if loop is not None:
self._condition_waiters = collections.deque() self._loop = loop
else:
self._loop = events.get_event_loop()
# TODO: Add __repr__() with len(_condition_waiters). # Lock as an attribute as in threading.Condition.
lock = Lock(loop=self._loop)
self._lock = lock
# Export the lock's locked(), acquire() and release() methods.
self.locked = lock.locked
self.acquire = lock.acquire
self.release = lock.release
self._waiters = collections.deque()
def __repr__(self):
res = super().__repr__()
extra = 'locked' if self.locked() else 'unlocked'
if self._waiters:
extra = '{},waiters:{}'.format(extra, len(self._waiters))
return '<{} [{}]>'.format(res[1:-1], extra)
@tasks.coroutine @tasks.coroutine
def wait(self): def wait(self):
@ -228,19 +248,19 @@ class Condition(Lock):
the same condition variable in another coroutine. Once the same condition variable in another coroutine. Once
awakened, it re-acquires the lock and returns True. awakened, it re-acquires the lock and returns True.
""" """
if not self._locked: if not self.locked():
raise RuntimeError('cannot wait on un-acquired lock') raise RuntimeError('cannot wait on un-acquired lock')
keep_lock = True keep_lock = True
self.release() self.release()
try: try:
fut = futures.Future(loop=self._loop) fut = futures.Future(loop=self._loop)
self._condition_waiters.append(fut) self._waiters.append(fut)
try: try:
yield from fut yield from fut
return True return True
finally: finally:
self._condition_waiters.remove(fut) self._waiters.remove(fut)
except GeneratorExit: except GeneratorExit:
keep_lock = False # Prevent yield in finally clause. keep_lock = False # Prevent yield in finally clause.
@ -275,11 +295,11 @@ class Condition(Lock):
wait() call until it can reacquire the lock. Since notify() does wait() call until it can reacquire the lock. Since notify() does
not release the lock, its caller should. not release the lock, its caller should.
""" """
if not self._locked: if not self.locked():
raise RuntimeError('cannot notify on un-acquired lock') raise RuntimeError('cannot notify on un-acquired lock')
idx = 0 idx = 0
for fut in self._condition_waiters: for fut in self._waiters:
if idx >= n: if idx >= n:
break break
@ -293,7 +313,17 @@ class Condition(Lock):
calling thread has not acquired the lock when this method is called, calling thread has not acquired the lock when this method is called,
a RuntimeError is raised. a RuntimeError is raised.
""" """
self.notify(len(self._condition_waiters)) self.notify(len(self._waiters))
def __enter__(self):
return self._lock.__enter__()
def __exit__(self, *args):
return self._lock.__exit__(*args)
def __iter__(self):
yield from self.acquire()
return self
class Semaphore: class Semaphore:
@ -310,10 +340,10 @@ class Semaphore:
counter; it defaults to 1. If the value given is less than 0, counter; it defaults to 1. If the value given is less than 0,
ValueError is raised. ValueError is raised.
The second optional argument determins can semophore be released more than The second optional argument determines if the semaphore can be released
initial internal counter value; it defaults to False. If the value given more than initial internal counter value; it defaults to False. If the
is True and number of release() is more than number of successfull value given is True and number of release() is more than number of
acquire() calls ValueError is raised. successful acquire() calls ValueError is raised.
""" """
def __init__(self, value=1, bound=False, *, loop=None): def __init__(self, value=1, bound=False, *, loop=None):
@ -330,12 +360,12 @@ class Semaphore:
self._loop = events.get_event_loop() self._loop = events.get_event_loop()
def __repr__(self): def __repr__(self):
# TODO: add waiters:N if > 0.
res = super().__repr__() res = super().__repr__()
return '<{} [{}]>'.format( extra = 'locked' if self._locked else 'unlocked,value:{}'.format(
res[1:-1], self._value)
'locked' if self._locked else 'unlocked,value:{}'.format( if self._waiters:
self._value)) extra = '{},waiters:{}'.format(extra, len(self._waiters))
return '<{} [{}]>'.format(res[1:-1], extra)
def locked(self): def locked(self):
"""Returns True if semaphore can not be acquired immediately.""" """Returns True if semaphore can not be acquired immediately."""
@ -373,7 +403,7 @@ class Semaphore:
When it was zero on entry and another coroutine is waiting for it to When it was zero on entry and another coroutine is waiting for it to
become larger than zero again, wake up that coroutine. become larger than zero again, wake up that coroutine.
If Semaphore is create with "bound" paramter equals true, then If Semaphore is created with "bound" parameter equals true, then
release() method checks to make sure its current value doesn't exceed release() method checks to make sure its current value doesn't exceed
its initial value. If it does, ValueError is raised. its initial value. If it does, ValueError is raised.
""" """

View File

@ -2,6 +2,7 @@
import unittest import unittest
import unittest.mock import unittest.mock
import re
from asyncio import events from asyncio import events
from asyncio import futures from asyncio import futures
@ -10,6 +11,15 @@ from asyncio import tasks
from asyncio import test_utils from asyncio import test_utils
STR_RGX_REPR = (
r'^<(?P<class>.*?) object at (?P<address>.*?)'
r'\[(?P<extras>'
r'(set|unset|locked|unlocked)(,value:\d)?(,waiters:\d+)?'
r')\]>\Z'
)
RGX_REPR = re.compile(STR_RGX_REPR)
class LockTests(unittest.TestCase): class LockTests(unittest.TestCase):
def setUp(self): def setUp(self):
@ -38,6 +48,7 @@ class LockTests(unittest.TestCase):
def test_repr(self): def test_repr(self):
lock = locks.Lock(loop=self.loop) lock = locks.Lock(loop=self.loop)
self.assertTrue(repr(lock).endswith('[unlocked]>')) self.assertTrue(repr(lock).endswith('[unlocked]>'))
self.assertTrue(RGX_REPR.match(repr(lock)))
@tasks.coroutine @tasks.coroutine
def acquire_lock(): def acquire_lock():
@ -45,6 +56,7 @@ class LockTests(unittest.TestCase):
self.loop.run_until_complete(acquire_lock()) self.loop.run_until_complete(acquire_lock())
self.assertTrue(repr(lock).endswith('[locked]>')) self.assertTrue(repr(lock).endswith('[locked]>'))
self.assertTrue(RGX_REPR.match(repr(lock)))
def test_lock(self): def test_lock(self):
lock = locks.Lock(loop=self.loop) lock = locks.Lock(loop=self.loop)
@ -239,9 +251,16 @@ class EventTests(unittest.TestCase):
def test_repr(self): def test_repr(self):
ev = locks.Event(loop=self.loop) ev = locks.Event(loop=self.loop)
self.assertTrue(repr(ev).endswith('[unset]>')) self.assertTrue(repr(ev).endswith('[unset]>'))
match = RGX_REPR.match(repr(ev))
self.assertEqual(match.group('extras'), 'unset')
ev.set() ev.set()
self.assertTrue(repr(ev).endswith('[set]>')) self.assertTrue(repr(ev).endswith('[set]>'))
self.assertTrue(RGX_REPR.match(repr(ev)))
ev._waiters.append(unittest.mock.Mock())
self.assertTrue('waiters:1' in repr(ev))
self.assertTrue(RGX_REPR.match(repr(ev)))
def test_wait(self): def test_wait(self):
ev = locks.Event(loop=self.loop) ev = locks.Event(loop=self.loop)
@ -440,7 +459,7 @@ class ConditionTests(unittest.TestCase):
self.assertRaises( self.assertRaises(
futures.CancelledError, futures.CancelledError,
self.loop.run_until_complete, wait) self.loop.run_until_complete, wait)
self.assertFalse(cond._condition_waiters) self.assertFalse(cond._waiters)
self.assertTrue(cond.locked()) self.assertTrue(cond.locked())
def test_wait_unacquired(self): def test_wait_unacquired(self):
@ -600,6 +619,45 @@ class ConditionTests(unittest.TestCase):
cond = locks.Condition(loop=self.loop) cond = locks.Condition(loop=self.loop)
self.assertRaises(RuntimeError, cond.notify_all) self.assertRaises(RuntimeError, cond.notify_all)
def test_repr(self):
cond = locks.Condition(loop=self.loop)
self.assertTrue('unlocked' in repr(cond))
self.assertTrue(RGX_REPR.match(repr(cond)))
self.loop.run_until_complete(cond.acquire())
self.assertTrue('locked' in repr(cond))
cond._waiters.append(unittest.mock.Mock())
self.assertTrue('waiters:1' in repr(cond))
self.assertTrue(RGX_REPR.match(repr(cond)))
cond._waiters.append(unittest.mock.Mock())
self.assertTrue('waiters:2' in repr(cond))
self.assertTrue(RGX_REPR.match(repr(cond)))
def test_context_manager(self):
cond = locks.Condition(loop=self.loop)
@tasks.coroutine
def acquire_cond():
return (yield from cond)
with self.loop.run_until_complete(acquire_cond()):
self.assertTrue(cond.locked())
self.assertFalse(cond.locked())
def test_context_manager_no_yield(self):
cond = locks.Condition(loop=self.loop)
try:
with cond:
self.fail('RuntimeError is not raised in with expression')
except RuntimeError as err:
self.assertEqual(
str(err),
'"yield from" should be used as context manager expression')
class SemaphoreTests(unittest.TestCase): class SemaphoreTests(unittest.TestCase):
@ -629,9 +687,20 @@ class SemaphoreTests(unittest.TestCase):
def test_repr(self): def test_repr(self):
sem = locks.Semaphore(loop=self.loop) sem = locks.Semaphore(loop=self.loop)
self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) self.assertTrue(repr(sem).endswith('[unlocked,value:1]>'))
self.assertTrue(RGX_REPR.match(repr(sem)))
self.loop.run_until_complete(sem.acquire()) self.loop.run_until_complete(sem.acquire())
self.assertTrue(repr(sem).endswith('[locked]>')) self.assertTrue(repr(sem).endswith('[locked]>'))
self.assertTrue('waiters' not in repr(sem))
self.assertTrue(RGX_REPR.match(repr(sem)))
sem._waiters.append(unittest.mock.Mock())
self.assertTrue('waiters:1' in repr(sem))
self.assertTrue(RGX_REPR.match(repr(sem)))
sem._waiters.append(unittest.mock.Mock())
self.assertTrue('waiters:2' in repr(sem))
self.assertTrue(RGX_REPR.match(repr(sem)))
def test_semaphore(self): def test_semaphore(self):
sem = locks.Semaphore(loop=self.loop) sem = locks.Semaphore(loop=self.loop)