gh-124309: Modernize the `staggered_race` implementation to support eager task factories (#124390)

Co-authored-by: Thomas Grainger <tagrain@gmail.com>
Co-authored-by: Jelle Zijlstra <jelle.zijlstra@gmail.com>
Co-authored-by: Carol Willing <carolcode@willingconsulting.com>
Co-authored-by: Kumar Aditya <kumaraditya@python.org>
This commit is contained in:
Peter Bierma 2024-09-26 01:11:17 -04:00 committed by GitHub
parent d9296529eb
commit de929f353c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 99 additions and 65 deletions

View File

@ -1144,7 +1144,7 @@ class BaseEventLoop(events.AbstractEventLoop):
(functools.partial(self._connect_sock, (functools.partial(self._connect_sock,
exceptions, addrinfo, laddr_infos) exceptions, addrinfo, laddr_infos)
for addrinfo in infos), for addrinfo in infos),
happy_eyeballs_delay, loop=self) happy_eyeballs_delay)
if sock is None: if sock is None:
exceptions = [exc for sub in exceptions for exc in sub] exceptions = [exc for sub in exceptions for exc in sub]

View File

@ -4,13 +4,14 @@ __all__ = 'staggered_race',
import contextlib import contextlib
from . import events
from . import exceptions as exceptions_mod
from . import locks from . import locks
from . import tasks from . import tasks
from . import taskgroups
class _Done(Exception):
pass
async def staggered_race(coro_fns, delay, *, loop=None): async def staggered_race(coro_fns, delay):
"""Run coroutines with staggered start times and take the first to finish. """Run coroutines with staggered start times and take the first to finish.
This method takes an iterable of coroutine functions. The first one is This method takes an iterable of coroutine functions. The first one is
@ -42,8 +43,6 @@ async def staggered_race(coro_fns, delay, *, loop=None):
delay: amount of time, in seconds, between starting coroutines. If delay: amount of time, in seconds, between starting coroutines. If
``None``, the coroutines will run sequentially. ``None``, the coroutines will run sequentially.
loop: the event loop to use.
Returns: Returns:
tuple *(winner_result, winner_index, exceptions)* where tuple *(winner_result, winner_index, exceptions)* where
@ -62,36 +61,11 @@ async def staggered_race(coro_fns, delay, *, loop=None):
""" """
# TODO: when we have aiter() and anext(), allow async iterables in coro_fns. # TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
loop = loop or events.get_running_loop()
enum_coro_fns = enumerate(coro_fns)
winner_result = None winner_result = None
winner_index = None winner_index = None
exceptions = [] exceptions = []
running_tasks = []
async def run_one_coro(previous_failed) -> None:
# Wait for the previous task to finish, or for delay seconds
if previous_failed is not None:
with contextlib.suppress(exceptions_mod.TimeoutError):
# Use asyncio.wait_for() instead of asyncio.wait() here, so
# that if we get cancelled at this point, Event.wait() is also
# cancelled, otherwise there will be a "Task destroyed but it is
# pending" later.
await tasks.wait_for(previous_failed.wait(), delay)
# Get the next coroutine to run
try:
this_index, coro_fn = next(enum_coro_fns)
except StopIteration:
return
# Start task that will run the next coroutine
this_failed = locks.Event()
next_task = loop.create_task(run_one_coro(this_failed))
running_tasks.append(next_task)
assert len(running_tasks) == this_index + 2
# Prepare place to put this coroutine's exceptions if not won
exceptions.append(None)
assert len(exceptions) == this_index + 1
async def run_one_coro(this_index, coro_fn, this_failed):
try: try:
result = await coro_fn() result = await coro_fn()
except (SystemExit, KeyboardInterrupt): except (SystemExit, KeyboardInterrupt):
@ -105,34 +79,17 @@ async def staggered_race(coro_fns, delay, *, loop=None):
assert winner_index is None assert winner_index is None
winner_index = this_index winner_index = this_index
winner_result = result winner_result = result
# Cancel all other tasks. We take care to not cancel the current raise _Done
# task as well. If we do so, then since there is no `await` after
# here and CancelledError are usually thrown at one, we will
# encounter a curious corner case where the current task will end
# up as done() == True, cancelled() == False, exception() ==
# asyncio.CancelledError. This behavior is specified in
# https://bugs.python.org/issue30048
for i, t in enumerate(running_tasks):
if i != this_index:
t.cancel()
first_task = loop.create_task(run_one_coro(None))
running_tasks.append(first_task)
try: try:
# Wait for a growing list of tasks to all finish: poor man's version of async with taskgroups.TaskGroup() as tg:
# curio's TaskGroup or trio's nursery for this_index, coro_fn in enumerate(coro_fns):
done_count = 0 this_failed = locks.Event()
while done_count != len(running_tasks): exceptions.append(None)
done, _ = await tasks.wait(running_tasks) tg.create_task(run_one_coro(this_index, coro_fn, this_failed))
done_count = len(done) with contextlib.suppress(TimeoutError):
# If run_one_coro raises an unhandled exception, it's probably a await tasks.wait_for(this_failed.wait(), delay)
# programming error, and I want to see it. except* _Done:
if __debug__: pass
for d in done:
if d.done() and not d.cancelled() and d.exception():
raise d.exception()
return winner_result, winner_index, exceptions return winner_result, winner_index, exceptions
finally:
# Make sure no tasks are left running if we leave this function
for t in running_tasks:
t.cancel()

View File

@ -213,6 +213,53 @@ class EagerTaskFactoryLoopTests:
self.run_coro(run()) self.run_coro(run())
def test_staggered_race_with_eager_tasks(self):
# See https://github.com/python/cpython/issues/124309
async def fail():
await asyncio.sleep(0)
raise ValueError("no good")
async def run():
winner, index, excs = await asyncio.staggered.staggered_race(
[
lambda: asyncio.sleep(2, result="sleep2"),
lambda: asyncio.sleep(1, result="sleep1"),
lambda: fail()
],
delay=0.25
)
self.assertEqual(winner, 'sleep1')
self.assertEqual(index, 1)
self.assertIsNone(excs[index])
self.assertIsInstance(excs[0], asyncio.CancelledError)
self.assertIsInstance(excs[2], ValueError)
self.run_coro(run())
def test_staggered_race_with_eager_tasks_no_delay(self):
# See https://github.com/python/cpython/issues/124309
async def fail():
raise ValueError("no good")
async def run():
winner, index, excs = await asyncio.staggered.staggered_race(
[
lambda: fail(),
lambda: asyncio.sleep(1, result="sleep1"),
lambda: asyncio.sleep(0, result="sleep0"),
],
delay=None
)
self.assertEqual(winner, 'sleep1')
self.assertEqual(index, 1)
self.assertIsNone(excs[index])
self.assertIsInstance(excs[0], ValueError)
self.assertEqual(len(excs), 2)
self.run_coro(run())
class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase): class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
Task = tasks._PyTask Task = tasks._PyTask

View File

@ -82,12 +82,14 @@ class StaggeredTests(unittest.IsolatedAsyncioTestCase):
async def coro(index): async def coro(index):
raise ValueError(index) raise ValueError(index)
for delay in [None, 0, 0.1, 1]:
with self.subTest(delay=delay):
winner, index, excs = await staggered_race( winner, index, excs = await staggered_race(
[ [
lambda: coro(0), lambda: coro(0),
lambda: coro(1), lambda: coro(1),
], ],
delay=None, delay=delay,
) )
self.assertIs(winner, None) self.assertIs(winner, None)
@ -95,3 +97,30 @@ class StaggeredTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(len(excs), 2) self.assertEqual(len(excs), 2)
self.assertIsInstance(excs[0], ValueError) self.assertIsInstance(excs[0], ValueError)
self.assertIsInstance(excs[1], ValueError) self.assertIsInstance(excs[1], ValueError)
async def test_long_delay_early_failure(self):
async def coro(index):
await asyncio.sleep(0) # Dummy coroutine for the 1 case
if index == 0:
await asyncio.sleep(0.1) # Dummy coroutine
raise ValueError(index)
return f'Res: {index}'
winner, index, excs = await staggered_race(
[
lambda: coro(0),
lambda: coro(1),
],
delay=10,
)
self.assertEqual(winner, 'Res: 1')
self.assertEqual(index, 1)
self.assertEqual(len(excs), 2)
self.assertIsInstance(excs[0], ValueError)
self.assertIsNone(excs[1])
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1 @@
Fixed :exc:`AssertionError` when using :func:`!asyncio.staggered.staggered_race` with :attr:`asyncio.eager_task_factory`.