asyncio: Change bounded semaphore into a subclass, like threading.[Bounded]Semaphore.

This commit is contained in:
Guido van Rossum 2013-11-23 15:09:16 -08:00
parent dcd340eeeb
commit 085869bfee
2 changed files with 20 additions and 18 deletions

View File

@ -336,22 +336,15 @@ class Semaphore:
Semaphores also support the context manager protocol.
The first optional argument gives the initial value for the internal
The optional argument gives the initial value for the internal
counter; it defaults to 1. If the value given is less than 0,
ValueError is raised.
The second optional argument determines if the semaphore can be released
more than initial internal counter value; it defaults to False. If the
value given is True and number of release() is more than number of
successful acquire() calls ValueError is raised.
"""
def __init__(self, value=1, bound=False, *, loop=None):
def __init__(self, value=1, *, loop=None):
if value < 0:
raise ValueError("Semaphore initial value must be >= 0")
self._value = value
self._bound = bound
self._bound_value = value
self._waiters = collections.deque()
self._locked = (value == 0)
if loop is not None:
@ -402,17 +395,9 @@ class Semaphore:
"""Release a semaphore, incrementing the internal counter by one.
When it was zero on entry and another coroutine is waiting for it to
become larger than zero again, wake up that coroutine.
If Semaphore is created with "bound" parameter equals true, then
release() method checks to make sure its current value doesn't exceed
its initial value. If it does, ValueError is raised.
"""
if self._bound and self._value >= self._bound_value:
raise ValueError('Semaphore released too many times')
self._value += 1
self._locked = False
for waiter in self._waiters:
if not waiter.done():
waiter.set_result(True)
@ -429,3 +414,20 @@ class Semaphore:
def __iter__(self):
yield from self.acquire()
return self
class BoundedSemaphore(Semaphore):
"""A bounded semaphore implementation.
This raises ValueError in release() if it would increase the value
above the initial value.
"""
def __init__(self, value=1, *, loop=None):
self._bound_value = value
super().__init__(value, loop=loop)
def release(self):
if self._value >= self._bound_value:
raise ValueError('BoundedSemaphore released too many times')
super().release()

View File

@ -805,7 +805,7 @@ class SemaphoreTests(unittest.TestCase):
self.assertFalse(sem._waiters)
def test_release_not_acquired(self):
sem = locks.Semaphore(bound=True, loop=self.loop)
sem = locks.BoundedSemaphore(loop=self.loop)
self.assertRaises(ValueError, sem.release)