From de929f353c413459834a2a37b2d9b0240673d874 Mon Sep 17 00:00:00 2001 From: Peter Bierma Date: Thu, 26 Sep 2024 01:11:17 -0400 Subject: [PATCH] gh-124309: Modernize the `staggered_race` implementation to support eager task factories (#124390) Co-authored-by: Thomas Grainger Co-authored-by: Jelle Zijlstra Co-authored-by: Carol Willing Co-authored-by: Kumar Aditya --- Lib/asyncio/base_events.py | 2 +- Lib/asyncio/staggered.py | 77 ++++--------------- .../test_asyncio/test_eager_task_factory.py | 47 +++++++++++ Lib/test/test_asyncio/test_staggered.py | 37 ++++++++- ...-09-23-18-18-23.gh-issue-124309.iFcarA.rst | 1 + 5 files changed, 99 insertions(+), 65 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 000647f57dd..ffcc0174e1e 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -1144,7 +1144,7 @@ class BaseEventLoop(events.AbstractEventLoop): (functools.partial(self._connect_sock, exceptions, addrinfo, laddr_infos) for addrinfo in infos), - happy_eyeballs_delay, loop=self) + happy_eyeballs_delay) if sock is None: exceptions = [exc for sub in exceptions for exc in sub] diff --git a/Lib/asyncio/staggered.py b/Lib/asyncio/staggered.py index c3a7441a7b0..4458d01dece 100644 --- a/Lib/asyncio/staggered.py +++ b/Lib/asyncio/staggered.py @@ -4,13 +4,14 @@ __all__ = 'staggered_race', import contextlib -from . import events -from . import exceptions as exceptions_mod from . import locks from . import tasks +from . import taskgroups +class _Done(Exception): + pass -async def staggered_race(coro_fns, delay, *, loop=None): +async def staggered_race(coro_fns, delay): """Run coroutines with staggered start times and take the first to finish. This method takes an iterable of coroutine functions. The first one is @@ -42,8 +43,6 @@ async def staggered_race(coro_fns, delay, *, loop=None): delay: amount of time, in seconds, between starting coroutines. If ``None``, the coroutines will run sequentially. - loop: the event loop to use. - Returns: tuple *(winner_result, winner_index, exceptions)* where @@ -62,36 +61,11 @@ async def staggered_race(coro_fns, delay, *, loop=None): """ # TODO: when we have aiter() and anext(), allow async iterables in coro_fns. - loop = loop or events.get_running_loop() - enum_coro_fns = enumerate(coro_fns) winner_result = None winner_index = None exceptions = [] - running_tasks = [] - - async def run_one_coro(previous_failed) -> None: - # Wait for the previous task to finish, or for delay seconds - if previous_failed is not None: - with contextlib.suppress(exceptions_mod.TimeoutError): - # Use asyncio.wait_for() instead of asyncio.wait() here, so - # that if we get cancelled at this point, Event.wait() is also - # cancelled, otherwise there will be a "Task destroyed but it is - # pending" later. - await tasks.wait_for(previous_failed.wait(), delay) - # Get the next coroutine to run - try: - this_index, coro_fn = next(enum_coro_fns) - except StopIteration: - return - # Start task that will run the next coroutine - this_failed = locks.Event() - next_task = loop.create_task(run_one_coro(this_failed)) - running_tasks.append(next_task) - assert len(running_tasks) == this_index + 2 - # Prepare place to put this coroutine's exceptions if not won - exceptions.append(None) - assert len(exceptions) == this_index + 1 + async def run_one_coro(this_index, coro_fn, this_failed): try: result = await coro_fn() except (SystemExit, KeyboardInterrupt): @@ -105,34 +79,17 @@ async def staggered_race(coro_fns, delay, *, loop=None): assert winner_index is None winner_index = this_index winner_result = result - # Cancel all other tasks. We take care to not cancel the current - # task as well. If we do so, then since there is no `await` after - # here and CancelledError are usually thrown at one, we will - # encounter a curious corner case where the current task will end - # up as done() == True, cancelled() == False, exception() == - # asyncio.CancelledError. This behavior is specified in - # https://bugs.python.org/issue30048 - for i, t in enumerate(running_tasks): - if i != this_index: - t.cancel() + raise _Done - first_task = loop.create_task(run_one_coro(None)) - running_tasks.append(first_task) try: - # Wait for a growing list of tasks to all finish: poor man's version of - # curio's TaskGroup or trio's nursery - done_count = 0 - while done_count != len(running_tasks): - done, _ = await tasks.wait(running_tasks) - done_count = len(done) - # If run_one_coro raises an unhandled exception, it's probably a - # programming error, and I want to see it. - if __debug__: - for d in done: - if d.done() and not d.cancelled() and d.exception(): - raise d.exception() - return winner_result, winner_index, exceptions - finally: - # Make sure no tasks are left running if we leave this function - for t in running_tasks: - t.cancel() + async with taskgroups.TaskGroup() as tg: + for this_index, coro_fn in enumerate(coro_fns): + this_failed = locks.Event() + exceptions.append(None) + tg.create_task(run_one_coro(this_index, coro_fn, this_failed)) + with contextlib.suppress(TimeoutError): + await tasks.wait_for(this_failed.wait(), delay) + except* _Done: + pass + + return winner_result, winner_index, exceptions diff --git a/Lib/test/test_asyncio/test_eager_task_factory.py b/Lib/test/test_asyncio/test_eager_task_factory.py index 0777f39b572..1579ad1188d 100644 --- a/Lib/test/test_asyncio/test_eager_task_factory.py +++ b/Lib/test/test_asyncio/test_eager_task_factory.py @@ -213,6 +213,53 @@ class EagerTaskFactoryLoopTests: self.run_coro(run()) + def test_staggered_race_with_eager_tasks(self): + # See https://github.com/python/cpython/issues/124309 + + async def fail(): + await asyncio.sleep(0) + raise ValueError("no good") + + async def run(): + winner, index, excs = await asyncio.staggered.staggered_race( + [ + lambda: asyncio.sleep(2, result="sleep2"), + lambda: asyncio.sleep(1, result="sleep1"), + lambda: fail() + ], + delay=0.25 + ) + self.assertEqual(winner, 'sleep1') + self.assertEqual(index, 1) + self.assertIsNone(excs[index]) + self.assertIsInstance(excs[0], asyncio.CancelledError) + self.assertIsInstance(excs[2], ValueError) + + self.run_coro(run()) + + def test_staggered_race_with_eager_tasks_no_delay(self): + # See https://github.com/python/cpython/issues/124309 + async def fail(): + raise ValueError("no good") + + async def run(): + winner, index, excs = await asyncio.staggered.staggered_race( + [ + lambda: fail(), + lambda: asyncio.sleep(1, result="sleep1"), + lambda: asyncio.sleep(0, result="sleep0"), + ], + delay=None + ) + self.assertEqual(winner, 'sleep1') + self.assertEqual(index, 1) + self.assertIsNone(excs[index]) + self.assertIsInstance(excs[0], ValueError) + self.assertEqual(len(excs), 2) + + self.run_coro(run()) + + class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase): Task = tasks._PyTask diff --git a/Lib/test/test_asyncio/test_staggered.py b/Lib/test/test_asyncio/test_staggered.py index e6e32f7dbbb..21a39b3f911 100644 --- a/Lib/test/test_asyncio/test_staggered.py +++ b/Lib/test/test_asyncio/test_staggered.py @@ -82,16 +82,45 @@ class StaggeredTests(unittest.IsolatedAsyncioTestCase): async def coro(index): raise ValueError(index) + for delay in [None, 0, 0.1, 1]: + with self.subTest(delay=delay): + winner, index, excs = await staggered_race( + [ + lambda: coro(0), + lambda: coro(1), + ], + delay=delay, + ) + + self.assertIs(winner, None) + self.assertIs(index, None) + self.assertEqual(len(excs), 2) + self.assertIsInstance(excs[0], ValueError) + self.assertIsInstance(excs[1], ValueError) + + async def test_long_delay_early_failure(self): + async def coro(index): + await asyncio.sleep(0) # Dummy coroutine for the 1 case + if index == 0: + await asyncio.sleep(0.1) # Dummy coroutine + raise ValueError(index) + + return f'Res: {index}' + winner, index, excs = await staggered_race( [ lambda: coro(0), lambda: coro(1), ], - delay=None, + delay=10, ) - self.assertIs(winner, None) - self.assertIs(index, None) + self.assertEqual(winner, 'Res: 1') + self.assertEqual(index, 1) self.assertEqual(len(excs), 2) self.assertIsInstance(excs[0], ValueError) - self.assertIsInstance(excs[1], ValueError) + self.assertIsNone(excs[1]) + + +if __name__ == "__main__": + unittest.main() diff --git a/Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst b/Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst new file mode 100644 index 00000000000..89610fa44bf --- /dev/null +++ b/Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst @@ -0,0 +1 @@ +Fixed :exc:`AssertionError` when using :func:`!asyncio.staggered.staggered_race` with :attr:`asyncio.eager_task_factory`.