mirror of https://github.com/python/cpython
bpo-43352: Add a Barrier object in asyncio lib (GH-24903)
Co-authored-by: Yury Selivanov <yury@edgedb.com> Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
This commit is contained in:
parent
20e6e5636a
commit
d03acd7270
|
@ -186,11 +186,16 @@ Threading-like synchronization primitives that can be used in Tasks.
|
|||
* - :class:`BoundedSemaphore`
|
||||
- A bounded semaphore.
|
||||
|
||||
* - :class:`Barrier`
|
||||
- A barrier object.
|
||||
|
||||
|
||||
.. rubric:: Examples
|
||||
|
||||
* :ref:`Using asyncio.Event <asyncio_example_sync_event>`.
|
||||
|
||||
* :ref:`Using asyncio.Barrier <asyncio_example_barrier>`.
|
||||
|
||||
* See also the documentation of asyncio
|
||||
:ref:`synchronization primitives <asyncio-sync>`.
|
||||
|
||||
|
@ -206,6 +211,9 @@ Exceptions
|
|||
* - :exc:`asyncio.CancelledError`
|
||||
- Raised when a Task is cancelled. See also :meth:`Task.cancel`.
|
||||
|
||||
* - :exc:`asyncio.BrokenBarrierError`
|
||||
- Raised when a Barrier is broken. See also :meth:`Barrier.wait`.
|
||||
|
||||
|
||||
.. rubric:: Examples
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@ asyncio has the following basic synchronization primitives:
|
|||
* :class:`Condition`
|
||||
* :class:`Semaphore`
|
||||
* :class:`BoundedSemaphore`
|
||||
* :class:`Barrier`
|
||||
|
||||
|
||||
---------
|
||||
|
@ -340,6 +341,115 @@ BoundedSemaphore
|
|||
.. versionchanged:: 3.10
|
||||
Removed the *loop* parameter.
|
||||
|
||||
|
||||
Barrier
|
||||
=======
|
||||
|
||||
.. class:: Barrier(parties, action=None)
|
||||
|
||||
A barrier object. Not thread-safe.
|
||||
|
||||
A barrier is a simple synchronization primitive that allows to block until
|
||||
*parties* number of tasks are waiting on it.
|
||||
Tasks can wait on the :meth:`~Barrier.wait` method and would be blocked until
|
||||
the specified number of tasks end up waiting on :meth:`~Barrier.wait`.
|
||||
At that point all of the waiting tasks would unblock simultaneously.
|
||||
|
||||
:keyword:`async with` can be used as an alternative to awaiting on
|
||||
:meth:`~Barrier.wait`.
|
||||
|
||||
The barrier can be reused any number of times.
|
||||
|
||||
.. _asyncio_example_barrier:
|
||||
|
||||
Example::
|
||||
|
||||
async def example_barrier():
|
||||
# barrier with 3 parties
|
||||
b = asyncio.Barrier(3)
|
||||
|
||||
# create 2 new waiting tasks
|
||||
asyncio.create_task(b.wait())
|
||||
asyncio.create_task(b.wait())
|
||||
|
||||
await asyncio.sleep(0)
|
||||
print(b)
|
||||
|
||||
# The third .wait() call passes the barrier
|
||||
await b.wait()
|
||||
print(b)
|
||||
print("barrier passed")
|
||||
|
||||
await asyncio.sleep(0)
|
||||
print(b)
|
||||
|
||||
asyncio.run(example_barrier())
|
||||
|
||||
Result of this example is::
|
||||
|
||||
<asyncio.locks.Barrier object at 0x... [filling, waiters:2/3]>
|
||||
<asyncio.locks.Barrier object at 0x... [draining, waiters:0/3]>
|
||||
barrier passed
|
||||
<asyncio.locks.Barrier object at 0x... [filling, waiters:0/3]>
|
||||
|
||||
.. versionadded:: 3.11
|
||||
|
||||
.. coroutinemethod:: wait()
|
||||
|
||||
Pass the barrier. When all the tasks party to the barrier have called
|
||||
this function, they are all unblocked simultaneously.
|
||||
|
||||
When a waiting or blocked task in the barrier is cancelled,
|
||||
this task exits the barrier which stays in the same state.
|
||||
If the state of the barrier is "filling", the number of waiting task
|
||||
decreases by 1.
|
||||
|
||||
The return value is an integer in the range of 0 to ``parties-1``, different
|
||||
for each task. This can be used to select a task to do some special
|
||||
housekeeping, e.g.::
|
||||
|
||||
...
|
||||
async with barrier as position:
|
||||
if position == 0:
|
||||
# Only one task print this
|
||||
print('End of *draining phasis*')
|
||||
|
||||
This method may raise a :class:`BrokenBarrierError` exception if the
|
||||
barrier is broken or reset while a task is waiting.
|
||||
It could raise a :exc:`CancelledError` if a task is cancelled.
|
||||
|
||||
.. coroutinemethod:: reset()
|
||||
|
||||
Return the barrier to the default, empty state. Any tasks waiting on it
|
||||
will receive the :class:`BrokenBarrierError` exception.
|
||||
|
||||
If a barrier is broken it may be better to just leave it and create a new one.
|
||||
|
||||
.. coroutinemethod:: abort()
|
||||
|
||||
Put the barrier into a broken state. This causes any active or future
|
||||
calls to :meth:`wait` to fail with the :class:`BrokenBarrierError`.
|
||||
Use this for example if one of the taks needs to abort, to avoid infinite
|
||||
waiting tasks.
|
||||
|
||||
.. attribute:: parties
|
||||
|
||||
The number of tasks required to pass the barrier.
|
||||
|
||||
.. attribute:: n_waiting
|
||||
|
||||
The number of tasks currently waiting in the barrier while filling.
|
||||
|
||||
.. attribute:: broken
|
||||
|
||||
A boolean that is ``True`` if the barrier is in the broken state.
|
||||
|
||||
|
||||
.. exception:: BrokenBarrierError
|
||||
|
||||
This exception, a subclass of :exc:`RuntimeError`, is raised when the
|
||||
:class:`Barrier` object is reset or broken.
|
||||
|
||||
---------
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
"""asyncio exceptions."""
|
||||
|
||||
|
||||
__all__ = ('CancelledError', 'InvalidStateError', 'TimeoutError',
|
||||
__all__ = ('BrokenBarrierError',
|
||||
'CancelledError', 'InvalidStateError', 'TimeoutError',
|
||||
'IncompleteReadError', 'LimitOverrunError',
|
||||
'SendfileNotAvailableError')
|
||||
|
||||
|
@ -55,3 +56,7 @@ class LimitOverrunError(Exception):
|
|||
|
||||
def __reduce__(self):
|
||||
return type(self), (self.args[0], self.consumed)
|
||||
|
||||
|
||||
class BrokenBarrierError(RuntimeError):
|
||||
"""Barrier is broken by barrier.abort() call."""
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
"""Synchronization primitives."""
|
||||
|
||||
__all__ = ('Lock', 'Event', 'Condition', 'Semaphore', 'BoundedSemaphore')
|
||||
__all__ = ('Lock', 'Event', 'Condition', 'Semaphore',
|
||||
'BoundedSemaphore', 'Barrier')
|
||||
|
||||
import collections
|
||||
import enum
|
||||
|
||||
from . import exceptions
|
||||
from . import mixins
|
||||
from . import tasks
|
||||
|
||||
|
||||
class _ContextManagerMixin:
|
||||
async def __aenter__(self):
|
||||
await self.acquire()
|
||||
|
@ -416,3 +417,155 @@ class BoundedSemaphore(Semaphore):
|
|||
if self._value >= self._bound_value:
|
||||
raise ValueError('BoundedSemaphore released too many times')
|
||||
super().release()
|
||||
|
||||
|
||||
|
||||
class _BarrierState(enum.Enum):
|
||||
FILLING = 'filling'
|
||||
DRAINING = 'draining'
|
||||
RESETTING = 'resetting'
|
||||
BROKEN = 'broken'
|
||||
|
||||
|
||||
class Barrier(mixins._LoopBoundMixin):
|
||||
"""Asyncio equivalent to threading.Barrier
|
||||
|
||||
Implements a Barrier primitive.
|
||||
Useful for synchronizing a fixed number of tasks at known synchronization
|
||||
points. Tasks block on 'wait()' and are simultaneously awoken once they
|
||||
have all made their call.
|
||||
"""
|
||||
|
||||
def __init__(self, parties):
|
||||
"""Create a barrier, initialised to 'parties' tasks."""
|
||||
if parties < 1:
|
||||
raise ValueError('parties must be > 0')
|
||||
|
||||
self._cond = Condition() # notify all tasks when state changes
|
||||
|
||||
self._parties = parties
|
||||
self._state = _BarrierState.FILLING
|
||||
self._count = 0 # count tasks in Barrier
|
||||
|
||||
def __repr__(self):
|
||||
res = super().__repr__()
|
||||
extra = f'{self._state.value}'
|
||||
if not self.broken:
|
||||
extra += f', waiters:{self.n_waiting}/{self.parties}'
|
||||
return f'<{res[1:-1]} [{extra}]>'
|
||||
|
||||
async def __aenter__(self):
|
||||
# wait for the barrier reaches the parties number
|
||||
# when start draining release and return index of waited task
|
||||
return await self.wait()
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
pass
|
||||
|
||||
async def wait(self):
|
||||
"""Wait for the barrier.
|
||||
|
||||
When the specified number of tasks have started waiting, they are all
|
||||
simultaneously awoken.
|
||||
Returns an unique and individual index number from 0 to 'parties-1'.
|
||||
"""
|
||||
async with self._cond:
|
||||
await self._block() # Block while the barrier drains or resets.
|
||||
try:
|
||||
index = self._count
|
||||
self._count += 1
|
||||
if index + 1 == self._parties:
|
||||
# We release the barrier
|
||||
await self._release()
|
||||
else:
|
||||
await self._wait()
|
||||
return index
|
||||
finally:
|
||||
self._count -= 1
|
||||
# Wake up any tasks waiting for barrier to drain.
|
||||
self._exit()
|
||||
|
||||
async def _block(self):
|
||||
# Block until the barrier is ready for us,
|
||||
# or raise an exception if it is broken.
|
||||
#
|
||||
# It is draining or resetting, wait until done
|
||||
# unless a CancelledError occurs
|
||||
await self._cond.wait_for(
|
||||
lambda: self._state not in (
|
||||
_BarrierState.DRAINING, _BarrierState.RESETTING
|
||||
)
|
||||
)
|
||||
|
||||
# see if the barrier is in a broken state
|
||||
if self._state is _BarrierState.BROKEN:
|
||||
raise exceptions.BrokenBarrierError("Barrier aborted")
|
||||
|
||||
async def _release(self):
|
||||
# Release the tasks waiting in the barrier.
|
||||
|
||||
# Enter draining state.
|
||||
# Next waiting tasks will be blocked until the end of draining.
|
||||
self._state = _BarrierState.DRAINING
|
||||
self._cond.notify_all()
|
||||
|
||||
async def _wait(self):
|
||||
# Wait in the barrier until we are released. Raise an exception
|
||||
# if the barrier is reset or broken.
|
||||
|
||||
# wait for end of filling
|
||||
# unless a CancelledError occurs
|
||||
await self._cond.wait_for(lambda: self._state is not _BarrierState.FILLING)
|
||||
|
||||
if self._state in (_BarrierState.BROKEN, _BarrierState.RESETTING):
|
||||
raise exceptions.BrokenBarrierError("Abort or reset of barrier")
|
||||
|
||||
def _exit(self):
|
||||
# If we are the last tasks to exit the barrier, signal any tasks
|
||||
# waiting for the barrier to drain.
|
||||
if self._count == 0:
|
||||
if self._state in (_BarrierState.RESETTING, _BarrierState.DRAINING):
|
||||
self._state = _BarrierState.FILLING
|
||||
self._cond.notify_all()
|
||||
|
||||
async def reset(self):
|
||||
"""Reset the barrier to the initial state.
|
||||
|
||||
Any tasks currently waiting will get the BrokenBarrier exception
|
||||
raised.
|
||||
"""
|
||||
async with self._cond:
|
||||
if self._count > 0:
|
||||
if self._state is not _BarrierState.RESETTING:
|
||||
#reset the barrier, waking up tasks
|
||||
self._state = _BarrierState.RESETTING
|
||||
else:
|
||||
self._state = _BarrierState.FILLING
|
||||
self._cond.notify_all()
|
||||
|
||||
async def abort(self):
|
||||
"""Place the barrier into a 'broken' state.
|
||||
|
||||
Useful in case of error. Any currently waiting tasks and tasks
|
||||
attempting to 'wait()' will have BrokenBarrierError raised.
|
||||
"""
|
||||
async with self._cond:
|
||||
self._state = _BarrierState.BROKEN
|
||||
self._cond.notify_all()
|
||||
|
||||
@property
|
||||
def parties(self):
|
||||
"""Return the number of tasks required to trip the barrier."""
|
||||
return self._parties
|
||||
|
||||
@property
|
||||
def n_waiting(self):
|
||||
"""Return the number of tasks currently waiting at the barrier."""
|
||||
if self._state is _BarrierState.FILLING:
|
||||
return self._count
|
||||
return 0
|
||||
|
||||
@property
|
||||
def broken(self):
|
||||
"""Return True if the barrier is in a broken state."""
|
||||
return self._state is _BarrierState.BROKEN
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
"""Tests for lock.py"""
|
||||
"""Tests for locks.py"""
|
||||
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
@ -9,7 +9,10 @@ import asyncio
|
|||
STR_RGX_REPR = (
|
||||
r'^<(?P<class>.*?) object at (?P<address>.*?)'
|
||||
r'\[(?P<extras>'
|
||||
r'(set|unset|locked|unlocked)(, value:\d)?(, waiters:\d+)?'
|
||||
r'(set|unset|locked|unlocked|filling|draining|resetting|broken)'
|
||||
r'(, value:\d)?'
|
||||
r'(, waiters:\d+)?'
|
||||
r'(, waiters:\d+\/\d+)?' # barrier
|
||||
r')\]>\Z'
|
||||
)
|
||||
RGX_REPR = re.compile(STR_RGX_REPR)
|
||||
|
@ -943,5 +946,576 @@ class SemaphoreTests(unittest.IsolatedAsyncioTestCase):
|
|||
)
|
||||
|
||||
|
||||
class BarrierTests(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.N = 5
|
||||
|
||||
def make_tasks(self, n, coro):
|
||||
tasks = [asyncio.create_task(coro()) for _ in range(n)]
|
||||
return tasks
|
||||
|
||||
async def gather_tasks(self, n, coro):
|
||||
tasks = self.make_tasks(n, coro)
|
||||
res = await asyncio.gather(*tasks)
|
||||
return res, tasks
|
||||
|
||||
async def test_barrier(self):
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
self.assertIn("filling", repr(barrier))
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"object Barrier can't be used in 'await' expression",
|
||||
):
|
||||
await barrier
|
||||
|
||||
self.assertIn("filling", repr(barrier))
|
||||
|
||||
async def test_repr(self):
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
|
||||
self.assertTrue(RGX_REPR.match(repr(barrier)))
|
||||
self.assertIn("filling", repr(barrier))
|
||||
|
||||
waiters = []
|
||||
async def wait(barrier):
|
||||
await barrier.wait()
|
||||
|
||||
incr = 2
|
||||
for i in range(incr):
|
||||
waiters.append(asyncio.create_task(wait(barrier)))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
self.assertTrue(RGX_REPR.match(repr(barrier)))
|
||||
self.assertTrue(f"waiters:{incr}/{self.N}" in repr(barrier))
|
||||
self.assertIn("filling", repr(barrier))
|
||||
|
||||
# create missing waiters
|
||||
for i in range(barrier.parties - barrier.n_waiting):
|
||||
waiters.append(asyncio.create_task(wait(barrier)))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
self.assertTrue(RGX_REPR.match(repr(barrier)))
|
||||
self.assertIn("draining", repr(barrier))
|
||||
|
||||
# add a part of waiters
|
||||
for i in range(incr):
|
||||
waiters.append(asyncio.create_task(wait(barrier)))
|
||||
await asyncio.sleep(0)
|
||||
# and reset
|
||||
await barrier.reset()
|
||||
|
||||
self.assertTrue(RGX_REPR.match(repr(barrier)))
|
||||
self.assertIn("resetting", repr(barrier))
|
||||
|
||||
# add a part of waiters again
|
||||
for i in range(incr):
|
||||
waiters.append(asyncio.create_task(wait(barrier)))
|
||||
await asyncio.sleep(0)
|
||||
# and abort
|
||||
await barrier.abort()
|
||||
|
||||
self.assertTrue(RGX_REPR.match(repr(barrier)))
|
||||
self.assertIn("broken", repr(barrier))
|
||||
self.assertTrue(barrier.broken)
|
||||
|
||||
# suppress unhandled exceptions
|
||||
await asyncio.gather(*waiters, return_exceptions=True)
|
||||
|
||||
async def test_barrier_parties(self):
|
||||
self.assertRaises(ValueError, lambda: asyncio.Barrier(0))
|
||||
self.assertRaises(ValueError, lambda: asyncio.Barrier(-4))
|
||||
|
||||
self.assertIsInstance(asyncio.Barrier(self.N), asyncio.Barrier)
|
||||
|
||||
async def test_context_manager(self):
|
||||
self.N = 3
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
results = []
|
||||
|
||||
async def coro():
|
||||
async with barrier as i:
|
||||
results.append(i)
|
||||
|
||||
await self.gather_tasks(self.N, coro)
|
||||
|
||||
self.assertListEqual(sorted(results), list(range(self.N)))
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertFalse(barrier.broken)
|
||||
|
||||
async def test_filling_one_task(self):
|
||||
barrier = asyncio.Barrier(1)
|
||||
|
||||
async def f():
|
||||
async with barrier as i:
|
||||
return True
|
||||
|
||||
ret = await f()
|
||||
|
||||
self.assertTrue(ret)
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertFalse(barrier.broken)
|
||||
|
||||
async def test_filling_one_task_twice(self):
|
||||
barrier = asyncio.Barrier(1)
|
||||
|
||||
t1 = asyncio.create_task(barrier.wait())
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
|
||||
t2 = asyncio.create_task(barrier.wait())
|
||||
await asyncio.sleep(0)
|
||||
|
||||
self.assertEqual(t1.result(), t2.result())
|
||||
self.assertEqual(t1.done(), t2.done())
|
||||
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertFalse(barrier.broken)
|
||||
|
||||
async def test_filling_task_by_task(self):
|
||||
self.N = 3
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
|
||||
t1 = asyncio.create_task(barrier.wait())
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual(barrier.n_waiting, 1)
|
||||
self.assertIn("filling", repr(barrier))
|
||||
|
||||
t2 = asyncio.create_task(barrier.wait())
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual(barrier.n_waiting, 2)
|
||||
self.assertIn("filling", repr(barrier))
|
||||
|
||||
t3 = asyncio.create_task(barrier.wait())
|
||||
await asyncio.sleep(0)
|
||||
|
||||
await asyncio.wait([t1, t2, t3])
|
||||
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertFalse(barrier.broken)
|
||||
|
||||
async def test_filling_tasks_wait_twice(self):
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
results = []
|
||||
|
||||
async def coro():
|
||||
async with barrier:
|
||||
results.append(True)
|
||||
|
||||
async with barrier:
|
||||
results.append(False)
|
||||
|
||||
await self.gather_tasks(self.N, coro)
|
||||
|
||||
self.assertEqual(len(results), self.N*2)
|
||||
self.assertEqual(results.count(True), self.N)
|
||||
self.assertEqual(results.count(False), self.N)
|
||||
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertFalse(barrier.broken)
|
||||
|
||||
async def test_filling_tasks_check_return_value(self):
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
results1 = []
|
||||
results2 = []
|
||||
|
||||
async def coro():
|
||||
async with barrier:
|
||||
results1.append(True)
|
||||
|
||||
async with barrier as i:
|
||||
results2.append(True)
|
||||
return i
|
||||
|
||||
res, _ = await self.gather_tasks(self.N, coro)
|
||||
|
||||
self.assertEqual(len(results1), self.N)
|
||||
self.assertTrue(all(results1))
|
||||
self.assertEqual(len(results2), self.N)
|
||||
self.assertTrue(all(results2))
|
||||
self.assertListEqual(sorted(res), list(range(self.N)))
|
||||
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertFalse(barrier.broken)
|
||||
|
||||
async def test_draining_state(self):
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
results = []
|
||||
|
||||
async def coro():
|
||||
async with barrier:
|
||||
# barrier state change to filling for the last task release
|
||||
results.append("draining" in repr(barrier))
|
||||
|
||||
await self.gather_tasks(self.N, coro)
|
||||
|
||||
self.assertEqual(len(results), self.N)
|
||||
self.assertEqual(results[-1], False)
|
||||
self.assertTrue(all(results[:self.N-1]))
|
||||
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertFalse(barrier.broken)
|
||||
|
||||
async def test_blocking_tasks_while_draining(self):
|
||||
rewait = 2
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
barrier_nowaiting = asyncio.Barrier(self.N - rewait)
|
||||
results = []
|
||||
rewait_n = rewait
|
||||
counter = 0
|
||||
|
||||
async def coro():
|
||||
nonlocal rewait_n
|
||||
|
||||
# first time waiting
|
||||
await barrier.wait()
|
||||
|
||||
# after wainting once for all tasks
|
||||
if rewait_n > 0:
|
||||
rewait_n -= 1
|
||||
# wait again only for rewait tasks
|
||||
await barrier.wait()
|
||||
else:
|
||||
# wait for end of draining state`
|
||||
await barrier_nowaiting.wait()
|
||||
# wait for other waiting tasks
|
||||
await barrier.wait()
|
||||
|
||||
# a success means that barrier_nowaiting
|
||||
# was waited for exactly N-rewait=3 times
|
||||
await self.gather_tasks(self.N, coro)
|
||||
|
||||
async def test_filling_tasks_cancel_one(self):
|
||||
self.N = 3
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
results = []
|
||||
|
||||
async def coro():
|
||||
await barrier.wait()
|
||||
results.append(True)
|
||||
|
||||
t1 = asyncio.create_task(coro())
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual(barrier.n_waiting, 1)
|
||||
|
||||
t2 = asyncio.create_task(coro())
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual(barrier.n_waiting, 2)
|
||||
|
||||
t1.cancel()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual(barrier.n_waiting, 1)
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await t1
|
||||
self.assertTrue(t1.cancelled())
|
||||
|
||||
t3 = asyncio.create_task(coro())
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual(barrier.n_waiting, 2)
|
||||
|
||||
t4 = asyncio.create_task(coro())
|
||||
await asyncio.gather(t2, t3, t4)
|
||||
|
||||
self.assertEqual(len(results), self.N)
|
||||
self.assertTrue(all(results))
|
||||
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertFalse(barrier.broken)
|
||||
|
||||
async def test_reset_barrier(self):
|
||||
barrier = asyncio.Barrier(1)
|
||||
|
||||
asyncio.create_task(barrier.reset())
|
||||
await asyncio.sleep(0)
|
||||
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertFalse(barrier.broken)
|
||||
|
||||
async def test_reset_barrier_while_tasks_waiting(self):
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
results = []
|
||||
|
||||
async def coro():
|
||||
try:
|
||||
await barrier.wait()
|
||||
except asyncio.BrokenBarrierError:
|
||||
results.append(True)
|
||||
|
||||
async def coro_reset():
|
||||
await barrier.reset()
|
||||
|
||||
# N-1 tasks waiting on barrier with N parties
|
||||
tasks = self.make_tasks(self.N-1, coro)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# reset the barrier
|
||||
asyncio.create_task(coro_reset())
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
self.assertEqual(len(results), self.N-1)
|
||||
self.assertTrue(all(results))
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertNotIn("resetting", repr(barrier))
|
||||
self.assertFalse(barrier.broken)
|
||||
|
||||
async def test_reset_barrier_when_tasks_half_draining(self):
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
results1 = []
|
||||
rest_of_tasks = self.N//2
|
||||
|
||||
async def coro():
|
||||
try:
|
||||
await barrier.wait()
|
||||
except asyncio.BrokenBarrierError:
|
||||
# catch here waiting tasks
|
||||
results1.append(True)
|
||||
else:
|
||||
# here drained task ouside the barrier
|
||||
if rest_of_tasks == barrier._count:
|
||||
# tasks outside the barrier
|
||||
await barrier.reset()
|
||||
|
||||
await self.gather_tasks(self.N, coro)
|
||||
|
||||
self.assertEqual(results1, [True]*rest_of_tasks)
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertNotIn("resetting", repr(barrier))
|
||||
self.assertFalse(barrier.broken)
|
||||
|
||||
async def test_reset_barrier_when_tasks_half_draining_half_blocking(self):
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
results1 = []
|
||||
results2 = []
|
||||
blocking_tasks = self.N//2
|
||||
count = 0
|
||||
|
||||
async def coro():
|
||||
nonlocal count
|
||||
try:
|
||||
await barrier.wait()
|
||||
except asyncio.BrokenBarrierError:
|
||||
# here catch still waiting tasks
|
||||
results1.append(True)
|
||||
|
||||
# so now waiting again to reach nb_parties
|
||||
await barrier.wait()
|
||||
else:
|
||||
count += 1
|
||||
if count > blocking_tasks:
|
||||
# reset now: raise asyncio.BrokenBarrierError for waiting tasks
|
||||
await barrier.reset()
|
||||
|
||||
# so now waiting again to reach nb_parties
|
||||
await barrier.wait()
|
||||
else:
|
||||
try:
|
||||
await barrier.wait()
|
||||
except asyncio.BrokenBarrierError:
|
||||
# here no catch - blocked tasks go to wait
|
||||
results2.append(True)
|
||||
|
||||
await self.gather_tasks(self.N, coro)
|
||||
|
||||
self.assertEqual(results1, [True]*blocking_tasks)
|
||||
self.assertEqual(results2, [])
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertNotIn("resetting", repr(barrier))
|
||||
self.assertFalse(barrier.broken)
|
||||
|
||||
async def test_reset_barrier_while_tasks_waiting_and_waiting_again(self):
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
results1 = []
|
||||
results2 = []
|
||||
|
||||
async def coro1():
|
||||
try:
|
||||
await barrier.wait()
|
||||
except asyncio.BrokenBarrierError:
|
||||
results1.append(True)
|
||||
finally:
|
||||
await barrier.wait()
|
||||
results2.append(True)
|
||||
|
||||
async def coro2():
|
||||
async with barrier:
|
||||
results2.append(True)
|
||||
|
||||
tasks = self.make_tasks(self.N-1, coro1)
|
||||
|
||||
# reset barrier, N-1 waiting tasks raise an BrokenBarrierError
|
||||
asyncio.create_task(barrier.reset())
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# complete waiting tasks in the `finally`
|
||||
asyncio.create_task(coro2())
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
self.assertFalse(barrier.broken)
|
||||
self.assertEqual(len(results1), self.N-1)
|
||||
self.assertTrue(all(results1))
|
||||
self.assertEqual(len(results2), self.N)
|
||||
self.assertTrue(all(results2))
|
||||
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
|
||||
|
||||
async def test_reset_barrier_while_tasks_draining(self):
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
results1 = []
|
||||
results2 = []
|
||||
results3 = []
|
||||
count = 0
|
||||
|
||||
async def coro():
|
||||
nonlocal count
|
||||
|
||||
i = await barrier.wait()
|
||||
count += 1
|
||||
if count == self.N:
|
||||
# last task exited from barrier
|
||||
await barrier.reset()
|
||||
|
||||
# wit here to reach the `parties`
|
||||
await barrier.wait()
|
||||
else:
|
||||
try:
|
||||
# second waiting
|
||||
await barrier.wait()
|
||||
|
||||
# N-1 tasks here
|
||||
results1.append(True)
|
||||
except Exception as e:
|
||||
# never goes here
|
||||
results2.append(True)
|
||||
|
||||
# Now, pass the barrier again
|
||||
# last wait, must be completed
|
||||
k = await barrier.wait()
|
||||
results3.append(True)
|
||||
|
||||
await self.gather_tasks(self.N, coro)
|
||||
|
||||
self.assertFalse(barrier.broken)
|
||||
self.assertTrue(all(results1))
|
||||
self.assertEqual(len(results1), self.N-1)
|
||||
self.assertEqual(len(results2), 0)
|
||||
self.assertEqual(len(results3), self.N)
|
||||
self.assertTrue(all(results3))
|
||||
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
|
||||
async def test_abort_barrier(self):
|
||||
barrier = asyncio.Barrier(1)
|
||||
|
||||
asyncio.create_task(barrier.abort())
|
||||
await asyncio.sleep(0)
|
||||
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertTrue(barrier.broken)
|
||||
|
||||
async def test_abort_barrier_when_tasks_half_draining_half_blocking(self):
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
results1 = []
|
||||
results2 = []
|
||||
blocking_tasks = self.N//2
|
||||
count = 0
|
||||
|
||||
async def coro():
|
||||
nonlocal count
|
||||
try:
|
||||
await barrier.wait()
|
||||
except asyncio.BrokenBarrierError:
|
||||
# here catch tasks waiting to drain
|
||||
results1.append(True)
|
||||
else:
|
||||
count += 1
|
||||
if count > blocking_tasks:
|
||||
# abort now: raise asyncio.BrokenBarrierError for all tasks
|
||||
await barrier.abort()
|
||||
else:
|
||||
try:
|
||||
await barrier.wait()
|
||||
except asyncio.BrokenBarrierError:
|
||||
# here catch blocked tasks (already drained)
|
||||
results2.append(True)
|
||||
|
||||
await self.gather_tasks(self.N, coro)
|
||||
|
||||
self.assertTrue(barrier.broken)
|
||||
self.assertEqual(results1, [True]*blocking_tasks)
|
||||
self.assertEqual(results2, [True]*(self.N-blocking_tasks-1))
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
self.assertNotIn("resetting", repr(barrier))
|
||||
|
||||
async def test_abort_barrier_when_exception(self):
|
||||
# test from threading.Barrier: see `lock_tests.test_reset`
|
||||
barrier = asyncio.Barrier(self.N)
|
||||
results1 = []
|
||||
results2 = []
|
||||
|
||||
async def coro():
|
||||
try:
|
||||
async with barrier as i :
|
||||
if i == self.N//2:
|
||||
raise RuntimeError
|
||||
async with barrier:
|
||||
results1.append(True)
|
||||
except asyncio.BrokenBarrierError:
|
||||
results2.append(True)
|
||||
except RuntimeError:
|
||||
await barrier.abort()
|
||||
|
||||
await self.gather_tasks(self.N, coro)
|
||||
|
||||
self.assertTrue(barrier.broken)
|
||||
self.assertEqual(len(results1), 0)
|
||||
self.assertEqual(len(results2), self.N-1)
|
||||
self.assertTrue(all(results2))
|
||||
self.assertEqual(barrier.n_waiting, 0)
|
||||
|
||||
async def test_abort_barrier_when_exception_then_resetting(self):
|
||||
# test from threading.Barrier: see `lock_tests.test_abort_and_reset``
|
||||
barrier1 = asyncio.Barrier(self.N)
|
||||
barrier2 = asyncio.Barrier(self.N)
|
||||
results1 = []
|
||||
results2 = []
|
||||
results3 = []
|
||||
|
||||
async def coro():
|
||||
try:
|
||||
i = await barrier1.wait()
|
||||
if i == self.N//2:
|
||||
raise RuntimeError
|
||||
await barrier1.wait()
|
||||
results1.append(True)
|
||||
except asyncio.BrokenBarrierError:
|
||||
results2.append(True)
|
||||
except RuntimeError:
|
||||
await barrier1.abort()
|
||||
|
||||
# Synchronize and reset the barrier. Must synchronize first so
|
||||
# that everyone has left it when we reset, and after so that no
|
||||
# one enters it before the reset.
|
||||
i = await barrier2.wait()
|
||||
if i == self.N//2:
|
||||
await barrier1.reset()
|
||||
await barrier2.wait()
|
||||
await barrier1.wait()
|
||||
results3.append(True)
|
||||
|
||||
await self.gather_tasks(self.N, coro)
|
||||
|
||||
self.assertFalse(barrier1.broken)
|
||||
self.assertEqual(len(results1), 0)
|
||||
self.assertEqual(len(results2), self.N-1)
|
||||
self.assertTrue(all(results2))
|
||||
self.assertEqual(len(results3), self.N)
|
||||
self.assertTrue(all(results3))
|
||||
|
||||
self.assertEqual(barrier1.n_waiting, 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Add an Barrier object in synchronization primitives of *asyncio* Lib in order to be consistant with Barrier from *threading* and *multiprocessing* libs*
|
Loading…
Reference in New Issue