mirror of https://github.com/python/cpython
bpo-46752: Add TaskGroup; add Task..cancelled(),.uncancel() (GH-31270)
asyncio/taskgroups.py is an adaptation of taskgroup.py from EdgeDb, with the following key changes: - Allow creating new tasks as long as the last task hasn't finished - Raise [Base]ExceptionGroup (directly) rather than TaskGroupError deriving from MultiError - Instead of monkey-patching the parent task's cancel() method, add a new public API to Task The Task class has a new internal flag, `_cancel_requested`, which is set when `.cancel()` is called successfully. The `.cancelling()` method returns the value of this flag. Further `.cancel()` calls while this flag is set return False. To reset this flag, call `.uncancel()`. Thus, a Task that catches and ignores `CancelledError` should call `.uncancel()` if it wants to be cancellable again; until it does so, it is deemed to be busy with uninterruptible cleanup. This new Task API helps solve the problem where TaskGroup needs to distinguish between whether the parent task being cancelled "from the outside" vs. "from inside". Co-authored-by: Yury Selivanov <yury@edgedb.com> Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
This commit is contained in:
parent
08ec80113b
commit
602630ac18
|
@ -17,6 +17,7 @@ from .queues import *
|
|||
from .streams import *
|
||||
from .subprocess import *
|
||||
from .tasks import *
|
||||
from .taskgroups import *
|
||||
from .threads import *
|
||||
from .transports import *
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ from . import coroutines
|
|||
def _task_repr_info(task):
|
||||
info = base_futures._future_repr_info(task)
|
||||
|
||||
if task._must_cancel:
|
||||
if task.cancelling() and not task.done():
|
||||
# replace status
|
||||
info[0] = 'cancelling'
|
||||
|
||||
|
|
|
@ -0,0 +1,235 @@
|
|||
# Adapted with permission from the EdgeDB project.
|
||||
|
||||
|
||||
__all__ = ["TaskGroup"]
|
||||
|
||||
import itertools
|
||||
import textwrap
|
||||
import traceback
|
||||
import types
|
||||
import weakref
|
||||
|
||||
from . import events
|
||||
from . import exceptions
|
||||
from . import tasks
|
||||
|
||||
class TaskGroup:
|
||||
|
||||
def __init__(self, *, name=None):
|
||||
if name is None:
|
||||
self._name = f'tg-{_name_counter()}'
|
||||
else:
|
||||
self._name = str(name)
|
||||
|
||||
self._entered = False
|
||||
self._exiting = False
|
||||
self._aborting = False
|
||||
self._loop = None
|
||||
self._parent_task = None
|
||||
self._parent_cancel_requested = False
|
||||
self._tasks = weakref.WeakSet()
|
||||
self._unfinished_tasks = 0
|
||||
self._errors = []
|
||||
self._base_error = None
|
||||
self._on_completed_fut = None
|
||||
|
||||
def get_name(self):
|
||||
return self._name
|
||||
|
||||
def __repr__(self):
|
||||
msg = f'<TaskGroup {self._name!r}'
|
||||
if self._tasks:
|
||||
msg += f' tasks:{len(self._tasks)}'
|
||||
if self._unfinished_tasks:
|
||||
msg += f' unfinished:{self._unfinished_tasks}'
|
||||
if self._errors:
|
||||
msg += f' errors:{len(self._errors)}'
|
||||
if self._aborting:
|
||||
msg += ' cancelling'
|
||||
elif self._entered:
|
||||
msg += ' entered'
|
||||
msg += '>'
|
||||
return msg
|
||||
|
||||
async def __aenter__(self):
|
||||
if self._entered:
|
||||
raise RuntimeError(
|
||||
f"TaskGroup {self!r} has been already entered")
|
||||
self._entered = True
|
||||
|
||||
if self._loop is None:
|
||||
self._loop = events.get_running_loop()
|
||||
|
||||
self._parent_task = tasks.current_task(self._loop)
|
||||
if self._parent_task is None:
|
||||
raise RuntimeError(
|
||||
f'TaskGroup {self!r} cannot determine the parent task')
|
||||
|
||||
return self
|
||||
|
||||
async def __aexit__(self, et, exc, tb):
|
||||
self._exiting = True
|
||||
propagate_cancellation_error = None
|
||||
|
||||
if (exc is not None and
|
||||
self._is_base_error(exc) and
|
||||
self._base_error is None):
|
||||
self._base_error = exc
|
||||
|
||||
if et is exceptions.CancelledError:
|
||||
if self._parent_cancel_requested:
|
||||
# Only if we did request task to cancel ourselves
|
||||
# we mark it as no longer cancelled.
|
||||
self._parent_task.uncancel()
|
||||
else:
|
||||
propagate_cancellation_error = et
|
||||
|
||||
if et is not None and not self._aborting:
|
||||
# Our parent task is being cancelled:
|
||||
#
|
||||
# async with TaskGroup() as g:
|
||||
# g.create_task(...)
|
||||
# await ... # <- CancelledError
|
||||
#
|
||||
if et is exceptions.CancelledError:
|
||||
propagate_cancellation_error = et
|
||||
|
||||
# or there's an exception in "async with":
|
||||
#
|
||||
# async with TaskGroup() as g:
|
||||
# g.create_task(...)
|
||||
# 1 / 0
|
||||
#
|
||||
self._abort()
|
||||
|
||||
# We use while-loop here because "self._on_completed_fut"
|
||||
# can be cancelled multiple times if our parent task
|
||||
# is being cancelled repeatedly (or even once, when
|
||||
# our own cancellation is already in progress)
|
||||
while self._unfinished_tasks:
|
||||
if self._on_completed_fut is None:
|
||||
self._on_completed_fut = self._loop.create_future()
|
||||
|
||||
try:
|
||||
await self._on_completed_fut
|
||||
except exceptions.CancelledError as ex:
|
||||
if not self._aborting:
|
||||
# Our parent task is being cancelled:
|
||||
#
|
||||
# async def wrapper():
|
||||
# async with TaskGroup() as g:
|
||||
# g.create_task(foo)
|
||||
#
|
||||
# "wrapper" is being cancelled while "foo" is
|
||||
# still running.
|
||||
propagate_cancellation_error = ex
|
||||
self._abort()
|
||||
|
||||
self._on_completed_fut = None
|
||||
|
||||
assert self._unfinished_tasks == 0
|
||||
self._on_completed_fut = None # no longer needed
|
||||
|
||||
if self._base_error is not None:
|
||||
raise self._base_error
|
||||
|
||||
if propagate_cancellation_error is not None:
|
||||
# The wrapping task was cancelled; since we're done with
|
||||
# closing all child tasks, just propagate the cancellation
|
||||
# request now.
|
||||
raise propagate_cancellation_error
|
||||
|
||||
if et is not None and et is not exceptions.CancelledError:
|
||||
self._errors.append(exc)
|
||||
|
||||
if self._errors:
|
||||
# Exceptions are heavy objects that can have object
|
||||
# cycles (bad for GC); let's not keep a reference to
|
||||
# a bunch of them.
|
||||
errors = self._errors
|
||||
self._errors = None
|
||||
|
||||
me = BaseExceptionGroup('unhandled errors in a TaskGroup', errors)
|
||||
raise me from None
|
||||
|
||||
def create_task(self, coro):
|
||||
if not self._entered:
|
||||
raise RuntimeError(f"TaskGroup {self!r} has not been entered")
|
||||
if self._exiting and self._unfinished_tasks == 0:
|
||||
raise RuntimeError(f"TaskGroup {self!r} is finished")
|
||||
task = self._loop.create_task(coro)
|
||||
task.add_done_callback(self._on_task_done)
|
||||
self._unfinished_tasks += 1
|
||||
self._tasks.add(task)
|
||||
return task
|
||||
|
||||
# Since Python 3.8 Tasks propagate all exceptions correctly,
|
||||
# except for KeyboardInterrupt and SystemExit which are
|
||||
# still considered special.
|
||||
|
||||
def _is_base_error(self, exc: BaseException) -> bool:
|
||||
assert isinstance(exc, BaseException)
|
||||
return isinstance(exc, (SystemExit, KeyboardInterrupt))
|
||||
|
||||
def _abort(self):
|
||||
self._aborting = True
|
||||
|
||||
for t in self._tasks:
|
||||
if not t.done():
|
||||
t.cancel()
|
||||
|
||||
def _on_task_done(self, task):
|
||||
self._unfinished_tasks -= 1
|
||||
assert self._unfinished_tasks >= 0
|
||||
|
||||
if self._on_completed_fut is not None and not self._unfinished_tasks:
|
||||
if not self._on_completed_fut.done():
|
||||
self._on_completed_fut.set_result(True)
|
||||
|
||||
if task.cancelled():
|
||||
return
|
||||
|
||||
exc = task.exception()
|
||||
if exc is None:
|
||||
return
|
||||
|
||||
self._errors.append(exc)
|
||||
if self._is_base_error(exc) and self._base_error is None:
|
||||
self._base_error = exc
|
||||
|
||||
if self._parent_task.done():
|
||||
# Not sure if this case is possible, but we want to handle
|
||||
# it anyways.
|
||||
self._loop.call_exception_handler({
|
||||
'message': f'Task {task!r} has errored out but its parent '
|
||||
f'task {self._parent_task} is already completed',
|
||||
'exception': exc,
|
||||
'task': task,
|
||||
})
|
||||
return
|
||||
|
||||
self._abort()
|
||||
if not self._parent_task.cancelling():
|
||||
# If parent task *is not* being cancelled, it means that we want
|
||||
# to manually cancel it to abort whatever is being run right now
|
||||
# in the TaskGroup. But we want to mark parent task as
|
||||
# "not cancelled" later in __aexit__. Example situation that
|
||||
# we need to handle:
|
||||
#
|
||||
# async def foo():
|
||||
# try:
|
||||
# async with TaskGroup() as g:
|
||||
# g.create_task(crash_soon())
|
||||
# await something # <- this needs to be canceled
|
||||
# # by the TaskGroup, e.g.
|
||||
# # foo() needs to be cancelled
|
||||
# except Exception:
|
||||
# # Ignore any exceptions raised in the TaskGroup
|
||||
# pass
|
||||
# await something_else # this line has to be called
|
||||
# # after TaskGroup is finished.
|
||||
self._parent_cancel_requested = True
|
||||
self._parent_task.cancel()
|
||||
|
||||
|
||||
_name_counter = itertools.count(1).__next__
|
|
@ -105,6 +105,7 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
|
|||
else:
|
||||
self._name = str(name)
|
||||
|
||||
self._cancel_requested = False
|
||||
self._must_cancel = False
|
||||
self._fut_waiter = None
|
||||
self._coro = coro
|
||||
|
@ -201,6 +202,9 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
|
|||
self._log_traceback = False
|
||||
if self.done():
|
||||
return False
|
||||
if self._cancel_requested:
|
||||
return False
|
||||
self._cancel_requested = True
|
||||
if self._fut_waiter is not None:
|
||||
if self._fut_waiter.cancel(msg=msg):
|
||||
# Leave self._fut_waiter; it may be a Task that
|
||||
|
@ -212,6 +216,16 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
|
|||
self._cancel_message = msg
|
||||
return True
|
||||
|
||||
def cancelling(self):
|
||||
return self._cancel_requested
|
||||
|
||||
def uncancel(self):
|
||||
if self._cancel_requested:
|
||||
self._cancel_requested = False
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def __step(self, exc=None):
|
||||
if self.done():
|
||||
raise exceptions.InvalidStateError(
|
||||
|
@ -634,7 +648,7 @@ def _ensure_future(coro_or_future, *, loop=None):
|
|||
loop = events._get_event_loop(stacklevel=4)
|
||||
try:
|
||||
return loop.create_task(coro_or_future)
|
||||
except RuntimeError:
|
||||
except RuntimeError:
|
||||
if not called_wrap_awaitable:
|
||||
coro_or_future.close()
|
||||
raise
|
||||
|
|
|
@ -0,0 +1,694 @@
|
|||
# Adapted with permission from the EdgeDB project.
|
||||
|
||||
|
||||
import asyncio
|
||||
|
||||
from asyncio import taskgroups
|
||||
import unittest
|
||||
|
||||
|
||||
# To prevent a warning "test altered the execution environment"
|
||||
def tearDownModule():
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
|
||||
class MyExc(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class MyBaseExc(BaseException):
|
||||
pass
|
||||
|
||||
|
||||
def get_error_types(eg):
|
||||
return {type(exc) for exc in eg.exceptions}
|
||||
|
||||
|
||||
class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_taskgroup_01(self):
|
||||
|
||||
async def foo1():
|
||||
await asyncio.sleep(0.1)
|
||||
return 42
|
||||
|
||||
async def foo2():
|
||||
await asyncio.sleep(0.2)
|
||||
return 11
|
||||
|
||||
async with taskgroups.TaskGroup() as g:
|
||||
t1 = g.create_task(foo1())
|
||||
t2 = g.create_task(foo2())
|
||||
|
||||
self.assertEqual(t1.result(), 42)
|
||||
self.assertEqual(t2.result(), 11)
|
||||
|
||||
async def test_taskgroup_02(self):
|
||||
|
||||
async def foo1():
|
||||
await asyncio.sleep(0.1)
|
||||
return 42
|
||||
|
||||
async def foo2():
|
||||
await asyncio.sleep(0.2)
|
||||
return 11
|
||||
|
||||
async with taskgroups.TaskGroup() as g:
|
||||
t1 = g.create_task(foo1())
|
||||
await asyncio.sleep(0.15)
|
||||
t2 = g.create_task(foo2())
|
||||
|
||||
self.assertEqual(t1.result(), 42)
|
||||
self.assertEqual(t2.result(), 11)
|
||||
|
||||
async def test_taskgroup_03(self):
|
||||
|
||||
async def foo1():
|
||||
await asyncio.sleep(1)
|
||||
return 42
|
||||
|
||||
async def foo2():
|
||||
await asyncio.sleep(0.2)
|
||||
return 11
|
||||
|
||||
async with taskgroups.TaskGroup() as g:
|
||||
t1 = g.create_task(foo1())
|
||||
await asyncio.sleep(0.15)
|
||||
# cancel t1 explicitly, i.e. everything should continue
|
||||
# working as expected.
|
||||
t1.cancel()
|
||||
|
||||
t2 = g.create_task(foo2())
|
||||
|
||||
self.assertTrue(t1.cancelled())
|
||||
self.assertEqual(t2.result(), 11)
|
||||
|
||||
async def test_taskgroup_04(self):
|
||||
|
||||
NUM = 0
|
||||
t2_cancel = False
|
||||
t2 = None
|
||||
|
||||
async def foo1():
|
||||
await asyncio.sleep(0.1)
|
||||
1 / 0
|
||||
|
||||
async def foo2():
|
||||
nonlocal NUM, t2_cancel
|
||||
try:
|
||||
await asyncio.sleep(1)
|
||||
except asyncio.CancelledError:
|
||||
t2_cancel = True
|
||||
raise
|
||||
NUM += 1
|
||||
|
||||
async def runner():
|
||||
nonlocal NUM, t2
|
||||
|
||||
async with taskgroups.TaskGroup() as g:
|
||||
g.create_task(foo1())
|
||||
t2 = g.create_task(foo2())
|
||||
|
||||
NUM += 10
|
||||
|
||||
with self.assertRaises(ExceptionGroup) as cm:
|
||||
await asyncio.create_task(runner())
|
||||
|
||||
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
|
||||
|
||||
self.assertEqual(NUM, 0)
|
||||
self.assertTrue(t2_cancel)
|
||||
self.assertTrue(t2.cancelled())
|
||||
|
||||
async def test_taskgroup_05(self):
|
||||
|
||||
NUM = 0
|
||||
t2_cancel = False
|
||||
runner_cancel = False
|
||||
|
||||
async def foo1():
|
||||
await asyncio.sleep(0.1)
|
||||
1 / 0
|
||||
|
||||
async def foo2():
|
||||
nonlocal NUM, t2_cancel
|
||||
try:
|
||||
await asyncio.sleep(5)
|
||||
except asyncio.CancelledError:
|
||||
t2_cancel = True
|
||||
raise
|
||||
NUM += 1
|
||||
|
||||
async def runner():
|
||||
nonlocal NUM, runner_cancel
|
||||
|
||||
async with taskgroups.TaskGroup() as g:
|
||||
g.create_task(foo1())
|
||||
g.create_task(foo1())
|
||||
g.create_task(foo1())
|
||||
g.create_task(foo2())
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
except asyncio.CancelledError:
|
||||
runner_cancel = True
|
||||
raise
|
||||
|
||||
NUM += 10
|
||||
|
||||
# The 3 foo1 sub tasks can be racy when the host is busy - if the
|
||||
# cancellation happens in the middle, we'll see partial sub errors here
|
||||
with self.assertRaises(ExceptionGroup) as cm:
|
||||
await asyncio.create_task(runner())
|
||||
|
||||
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
|
||||
self.assertEqual(NUM, 0)
|
||||
self.assertTrue(t2_cancel)
|
||||
self.assertTrue(runner_cancel)
|
||||
|
||||
async def test_taskgroup_06(self):
|
||||
|
||||
NUM = 0
|
||||
|
||||
async def foo():
|
||||
nonlocal NUM
|
||||
try:
|
||||
await asyncio.sleep(5)
|
||||
except asyncio.CancelledError:
|
||||
NUM += 1
|
||||
raise
|
||||
|
||||
async def runner():
|
||||
async with taskgroups.TaskGroup() as g:
|
||||
for _ in range(5):
|
||||
g.create_task(foo())
|
||||
|
||||
r = asyncio.create_task(runner())
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
self.assertFalse(r.done())
|
||||
r.cancel()
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await r
|
||||
|
||||
self.assertEqual(NUM, 5)
|
||||
|
||||
async def test_taskgroup_07(self):
|
||||
|
||||
NUM = 0
|
||||
|
||||
async def foo():
|
||||
nonlocal NUM
|
||||
try:
|
||||
await asyncio.sleep(5)
|
||||
except asyncio.CancelledError:
|
||||
NUM += 1
|
||||
raise
|
||||
|
||||
async def runner():
|
||||
nonlocal NUM
|
||||
async with taskgroups.TaskGroup() as g:
|
||||
for _ in range(5):
|
||||
g.create_task(foo())
|
||||
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
except asyncio.CancelledError:
|
||||
NUM += 10
|
||||
raise
|
||||
|
||||
r = asyncio.create_task(runner())
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
self.assertFalse(r.done())
|
||||
r.cancel()
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await r
|
||||
|
||||
self.assertEqual(NUM, 15)
|
||||
|
||||
async def test_taskgroup_08(self):
|
||||
|
||||
async def foo():
|
||||
await asyncio.sleep(0.1)
|
||||
1 / 0
|
||||
|
||||
async def runner():
|
||||
async with taskgroups.TaskGroup() as g:
|
||||
for _ in range(5):
|
||||
g.create_task(foo())
|
||||
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
|
||||
r = asyncio.create_task(runner())
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
self.assertFalse(r.done())
|
||||
r.cancel()
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await r
|
||||
|
||||
async def test_taskgroup_09(self):
|
||||
|
||||
t1 = t2 = None
|
||||
|
||||
async def foo1():
|
||||
await asyncio.sleep(1)
|
||||
return 42
|
||||
|
||||
async def foo2():
|
||||
await asyncio.sleep(2)
|
||||
return 11
|
||||
|
||||
async def runner():
|
||||
nonlocal t1, t2
|
||||
async with taskgroups.TaskGroup() as g:
|
||||
t1 = g.create_task(foo1())
|
||||
t2 = g.create_task(foo2())
|
||||
await asyncio.sleep(0.1)
|
||||
1 / 0
|
||||
|
||||
try:
|
||||
await runner()
|
||||
except ExceptionGroup as t:
|
||||
self.assertEqual(get_error_types(t), {ZeroDivisionError})
|
||||
else:
|
||||
self.fail('ExceptionGroup was not raised')
|
||||
|
||||
self.assertTrue(t1.cancelled())
|
||||
self.assertTrue(t2.cancelled())
|
||||
|
||||
async def test_taskgroup_10(self):
|
||||
|
||||
t1 = t2 = None
|
||||
|
||||
async def foo1():
|
||||
await asyncio.sleep(1)
|
||||
return 42
|
||||
|
||||
async def foo2():
|
||||
await asyncio.sleep(2)
|
||||
return 11
|
||||
|
||||
async def runner():
|
||||
nonlocal t1, t2
|
||||
async with taskgroups.TaskGroup() as g:
|
||||
t1 = g.create_task(foo1())
|
||||
t2 = g.create_task(foo2())
|
||||
1 / 0
|
||||
|
||||
try:
|
||||
await runner()
|
||||
except ExceptionGroup as t:
|
||||
self.assertEqual(get_error_types(t), {ZeroDivisionError})
|
||||
else:
|
||||
self.fail('ExceptionGroup was not raised')
|
||||
|
||||
self.assertTrue(t1.cancelled())
|
||||
self.assertTrue(t2.cancelled())
|
||||
|
||||
async def test_taskgroup_11(self):
|
||||
|
||||
async def foo():
|
||||
await asyncio.sleep(0.1)
|
||||
1 / 0
|
||||
|
||||
async def runner():
|
||||
async with taskgroups.TaskGroup():
|
||||
async with taskgroups.TaskGroup() as g2:
|
||||
for _ in range(5):
|
||||
g2.create_task(foo())
|
||||
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
|
||||
r = asyncio.create_task(runner())
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
self.assertFalse(r.done())
|
||||
r.cancel()
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await r
|
||||
|
||||
async def test_taskgroup_12(self):
|
||||
|
||||
async def foo():
|
||||
await asyncio.sleep(0.1)
|
||||
1 / 0
|
||||
|
||||
async def runner():
|
||||
async with taskgroups.TaskGroup() as g1:
|
||||
g1.create_task(asyncio.sleep(10))
|
||||
|
||||
async with taskgroups.TaskGroup() as g2:
|
||||
for _ in range(5):
|
||||
g2.create_task(foo())
|
||||
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
|
||||
r = asyncio.create_task(runner())
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
self.assertFalse(r.done())
|
||||
r.cancel()
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await r
|
||||
|
||||
async def test_taskgroup_13(self):
|
||||
|
||||
async def crash_after(t):
|
||||
await asyncio.sleep(t)
|
||||
raise ValueError(t)
|
||||
|
||||
async def runner():
|
||||
async with taskgroups.TaskGroup(name='g1') as g1:
|
||||
g1.create_task(crash_after(0.1))
|
||||
|
||||
async with taskgroups.TaskGroup(name='g2') as g2:
|
||||
g2.create_task(crash_after(0.2))
|
||||
|
||||
r = asyncio.create_task(runner())
|
||||
with self.assertRaises(ExceptionGroup) as cm:
|
||||
await r
|
||||
|
||||
self.assertEqual(get_error_types(cm.exception), {ValueError})
|
||||
|
||||
async def test_taskgroup_14(self):
|
||||
|
||||
async def crash_after(t):
|
||||
await asyncio.sleep(t)
|
||||
raise ValueError(t)
|
||||
|
||||
async def runner():
|
||||
async with taskgroups.TaskGroup(name='g1') as g1:
|
||||
g1.create_task(crash_after(10))
|
||||
|
||||
async with taskgroups.TaskGroup(name='g2') as g2:
|
||||
g2.create_task(crash_after(0.1))
|
||||
|
||||
r = asyncio.create_task(runner())
|
||||
with self.assertRaises(ExceptionGroup) as cm:
|
||||
await r
|
||||
|
||||
self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
|
||||
self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ValueError})
|
||||
|
||||
async def test_taskgroup_15(self):
|
||||
|
||||
async def crash_soon():
|
||||
await asyncio.sleep(0.3)
|
||||
1 / 0
|
||||
|
||||
async def runner():
|
||||
async with taskgroups.TaskGroup(name='g1') as g1:
|
||||
g1.create_task(crash_soon())
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
except asyncio.CancelledError:
|
||||
await asyncio.sleep(0.5)
|
||||
raise
|
||||
|
||||
r = asyncio.create_task(runner())
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
self.assertFalse(r.done())
|
||||
r.cancel()
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await r
|
||||
|
||||
async def test_taskgroup_16(self):
|
||||
|
||||
async def crash_soon():
|
||||
await asyncio.sleep(0.3)
|
||||
1 / 0
|
||||
|
||||
async def nested_runner():
|
||||
async with taskgroups.TaskGroup(name='g1') as g1:
|
||||
g1.create_task(crash_soon())
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
except asyncio.CancelledError:
|
||||
await asyncio.sleep(0.5)
|
||||
raise
|
||||
|
||||
async def runner():
|
||||
t = asyncio.create_task(nested_runner())
|
||||
await t
|
||||
|
||||
r = asyncio.create_task(runner())
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
self.assertFalse(r.done())
|
||||
r.cancel()
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await r
|
||||
|
||||
async def test_taskgroup_17(self):
|
||||
NUM = 0
|
||||
|
||||
async def runner():
|
||||
nonlocal NUM
|
||||
async with taskgroups.TaskGroup():
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
except asyncio.CancelledError:
|
||||
NUM += 10
|
||||
raise
|
||||
|
||||
r = asyncio.create_task(runner())
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
self.assertFalse(r.done())
|
||||
r.cancel()
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await r
|
||||
|
||||
self.assertEqual(NUM, 10)
|
||||
|
||||
async def test_taskgroup_18(self):
|
||||
NUM = 0
|
||||
|
||||
async def runner():
|
||||
nonlocal NUM
|
||||
async with taskgroups.TaskGroup():
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
except asyncio.CancelledError:
|
||||
NUM += 10
|
||||
# This isn't a good idea, but we have to support
|
||||
# this weird case.
|
||||
raise MyExc
|
||||
|
||||
r = asyncio.create_task(runner())
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
self.assertFalse(r.done())
|
||||
r.cancel()
|
||||
|
||||
try:
|
||||
await r
|
||||
except ExceptionGroup as t:
|
||||
self.assertEqual(get_error_types(t),{MyExc})
|
||||
else:
|
||||
self.fail('ExceptionGroup was not raised')
|
||||
|
||||
self.assertEqual(NUM, 10)
|
||||
|
||||
async def test_taskgroup_19(self):
|
||||
async def crash_soon():
|
||||
await asyncio.sleep(0.1)
|
||||
1 / 0
|
||||
|
||||
async def nested():
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
finally:
|
||||
raise MyExc
|
||||
|
||||
async def runner():
|
||||
async with taskgroups.TaskGroup() as g:
|
||||
g.create_task(crash_soon())
|
||||
await nested()
|
||||
|
||||
r = asyncio.create_task(runner())
|
||||
try:
|
||||
await r
|
||||
except ExceptionGroup as t:
|
||||
self.assertEqual(get_error_types(t), {MyExc, ZeroDivisionError})
|
||||
else:
|
||||
self.fail('TasgGroupError was not raised')
|
||||
|
||||
async def test_taskgroup_20(self):
|
||||
async def crash_soon():
|
||||
await asyncio.sleep(0.1)
|
||||
1 / 0
|
||||
|
||||
async def nested():
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
finally:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
async def runner():
|
||||
async with taskgroups.TaskGroup() as g:
|
||||
g.create_task(crash_soon())
|
||||
await nested()
|
||||
|
||||
with self.assertRaises(KeyboardInterrupt):
|
||||
await runner()
|
||||
|
||||
async def test_taskgroup_20a(self):
|
||||
async def crash_soon():
|
||||
await asyncio.sleep(0.1)
|
||||
1 / 0
|
||||
|
||||
async def nested():
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
finally:
|
||||
raise MyBaseExc
|
||||
|
||||
async def runner():
|
||||
async with taskgroups.TaskGroup() as g:
|
||||
g.create_task(crash_soon())
|
||||
await nested()
|
||||
|
||||
with self.assertRaises(BaseExceptionGroup) as cm:
|
||||
await runner()
|
||||
|
||||
self.assertEqual(
|
||||
get_error_types(cm.exception), {MyBaseExc, ZeroDivisionError}
|
||||
)
|
||||
|
||||
async def _test_taskgroup_21(self):
|
||||
# This test doesn't work as asyncio, currently, doesn't
|
||||
# correctly propagate KeyboardInterrupt (or SystemExit) --
|
||||
# those cause the event loop itself to crash.
|
||||
# (Compare to the previous (passing) test -- that one raises
|
||||
# a plain exception but raises KeyboardInterrupt in nested();
|
||||
# this test does it the other way around.)
|
||||
|
||||
async def crash_soon():
|
||||
await asyncio.sleep(0.1)
|
||||
raise KeyboardInterrupt
|
||||
|
||||
async def nested():
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
finally:
|
||||
raise TypeError
|
||||
|
||||
async def runner():
|
||||
async with taskgroups.TaskGroup() as g:
|
||||
g.create_task(crash_soon())
|
||||
await nested()
|
||||
|
||||
with self.assertRaises(KeyboardInterrupt):
|
||||
await runner()
|
||||
|
||||
async def test_taskgroup_21a(self):
|
||||
|
||||
async def crash_soon():
|
||||
await asyncio.sleep(0.1)
|
||||
raise MyBaseExc
|
||||
|
||||
async def nested():
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
finally:
|
||||
raise TypeError
|
||||
|
||||
async def runner():
|
||||
async with taskgroups.TaskGroup() as g:
|
||||
g.create_task(crash_soon())
|
||||
await nested()
|
||||
|
||||
with self.assertRaises(BaseExceptionGroup) as cm:
|
||||
await runner()
|
||||
|
||||
self.assertEqual(get_error_types(cm.exception), {MyBaseExc, TypeError})
|
||||
|
||||
async def test_taskgroup_22(self):
|
||||
|
||||
async def foo1():
|
||||
await asyncio.sleep(1)
|
||||
return 42
|
||||
|
||||
async def foo2():
|
||||
await asyncio.sleep(2)
|
||||
return 11
|
||||
|
||||
async def runner():
|
||||
async with taskgroups.TaskGroup() as g:
|
||||
g.create_task(foo1())
|
||||
g.create_task(foo2())
|
||||
|
||||
r = asyncio.create_task(runner())
|
||||
await asyncio.sleep(0.05)
|
||||
r.cancel()
|
||||
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await r
|
||||
|
||||
async def test_taskgroup_23(self):
|
||||
|
||||
async def do_job(delay):
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
async with taskgroups.TaskGroup() as g:
|
||||
for count in range(10):
|
||||
await asyncio.sleep(0.1)
|
||||
g.create_task(do_job(0.3))
|
||||
if count == 5:
|
||||
self.assertLess(len(g._tasks), 5)
|
||||
await asyncio.sleep(1.35)
|
||||
self.assertEqual(len(g._tasks), 0)
|
||||
|
||||
async def test_taskgroup_24(self):
|
||||
|
||||
async def root(g):
|
||||
await asyncio.sleep(0.1)
|
||||
g.create_task(coro1(0.1))
|
||||
g.create_task(coro1(0.2))
|
||||
|
||||
async def coro1(delay):
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
async def runner():
|
||||
async with taskgroups.TaskGroup() as g:
|
||||
g.create_task(root(g))
|
||||
|
||||
await runner()
|
||||
|
||||
async def test_taskgroup_25(self):
|
||||
nhydras = 0
|
||||
|
||||
async def hydra(g):
|
||||
nonlocal nhydras
|
||||
nhydras += 1
|
||||
await asyncio.sleep(0.01)
|
||||
g.create_task(hydra(g))
|
||||
g.create_task(hydra(g))
|
||||
|
||||
async def hercules():
|
||||
while nhydras < 10:
|
||||
await asyncio.sleep(0.015)
|
||||
1 / 0
|
||||
|
||||
async def runner():
|
||||
async with taskgroups.TaskGroup() as g:
|
||||
g.create_task(hydra(g))
|
||||
g.create_task(hercules())
|
||||
|
||||
with self.assertRaises(ExceptionGroup) as cm:
|
||||
await runner()
|
||||
|
||||
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
|
||||
self.assertGreaterEqual(nhydras, 10)
|
|
@ -496,6 +496,51 @@ class BaseTaskTests:
|
|||
# This also distinguishes from the initial has_cycle=None.
|
||||
self.assertEqual(has_cycle, False)
|
||||
|
||||
|
||||
def test_cancelling(self):
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
async def task():
|
||||
await asyncio.sleep(10)
|
||||
|
||||
try:
|
||||
t = self.new_task(loop, task())
|
||||
self.assertFalse(t.cancelling())
|
||||
self.assertNotIn(" cancelling ", repr(t))
|
||||
self.assertTrue(t.cancel())
|
||||
self.assertTrue(t.cancelling())
|
||||
self.assertIn(" cancelling ", repr(t))
|
||||
self.assertFalse(t.cancel())
|
||||
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
loop.run_until_complete(t)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
def test_uncancel(self):
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
async def task():
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
except asyncio.CancelledError:
|
||||
asyncio.current_task().uncancel()
|
||||
await asyncio.sleep(10)
|
||||
|
||||
try:
|
||||
t = self.new_task(loop, task())
|
||||
loop.run_until_complete(asyncio.sleep(0.01))
|
||||
self.assertTrue(t.cancel()) # Cancel first sleep
|
||||
self.assertIn(" cancelling ", repr(t))
|
||||
loop.run_until_complete(asyncio.sleep(0.01))
|
||||
self.assertNotIn(" cancelling ", repr(t)) # after .uncancel()
|
||||
self.assertTrue(t.cancel()) # Cancel second sleep
|
||||
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
loop.run_until_complete(t)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
def test_cancel(self):
|
||||
|
||||
def gen():
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
Add task groups to asyncio (structured concurrency, inspired by Trio's nurseries).
|
||||
This also introduces a change to task cancellation, where a cancelled task can't be cancelled again until it calls .uncancel().
|
|
@ -91,6 +91,7 @@ typedef struct {
|
|||
PyObject *task_context;
|
||||
int task_must_cancel;
|
||||
int task_log_destroy_pending;
|
||||
int task_cancel_requested;
|
||||
} TaskObj;
|
||||
|
||||
typedef struct {
|
||||
|
@ -2039,6 +2040,7 @@ _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop,
|
|||
Py_CLEAR(self->task_fut_waiter);
|
||||
self->task_must_cancel = 0;
|
||||
self->task_log_destroy_pending = 1;
|
||||
self->task_cancel_requested = 0;
|
||||
Py_INCREF(coro);
|
||||
Py_XSETREF(self->task_coro, coro);
|
||||
|
||||
|
@ -2205,6 +2207,11 @@ _asyncio_Task_cancel_impl(TaskObj *self, PyObject *msg)
|
|||
Py_RETURN_FALSE;
|
||||
}
|
||||
|
||||
if (self->task_cancel_requested) {
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
self->task_cancel_requested = 1;
|
||||
|
||||
if (self->task_fut_waiter) {
|
||||
PyObject *res;
|
||||
int is_true;
|
||||
|
@ -2232,6 +2239,56 @@ _asyncio_Task_cancel_impl(TaskObj *self, PyObject *msg)
|
|||
Py_RETURN_TRUE;
|
||||
}
|
||||
|
||||
/*[clinic input]
|
||||
_asyncio.Task.cancelling
|
||||
|
||||
Return True if the task is in the process of being cancelled.
|
||||
|
||||
This is set once .cancel() is called
|
||||
and remains set until .uncancel() is called.
|
||||
|
||||
As long as this flag is set, further .cancel() calls will be ignored,
|
||||
until .uncancel() is called to reset it.
|
||||
[clinic start generated code]*/
|
||||
|
||||
static PyObject *
|
||||
_asyncio_Task_cancelling_impl(TaskObj *self)
|
||||
/*[clinic end generated code: output=803b3af96f917d7e input=c50e50f9c3ca4676]*/
|
||||
/*[clinic end generated code]*/
|
||||
{
|
||||
if (self->task_cancel_requested) {
|
||||
Py_RETURN_TRUE;
|
||||
}
|
||||
else {
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
}
|
||||
|
||||
/*[clinic input]
|
||||
_asyncio.Task.uncancel
|
||||
|
||||
Reset the flag returned by cancelling().
|
||||
|
||||
This should be used by tasks that catch CancelledError
|
||||
and wish to continue indefinitely until they are cancelled again.
|
||||
|
||||
Returns the previous value of the flag.
|
||||
[clinic start generated code]*/
|
||||
|
||||
static PyObject *
|
||||
_asyncio_Task_uncancel_impl(TaskObj *self)
|
||||
/*[clinic end generated code: output=58184d236a817d3c input=5db95e28fcb6f7cd]*/
|
||||
/*[clinic end generated code]*/
|
||||
{
|
||||
if (self->task_cancel_requested) {
|
||||
self->task_cancel_requested = 0;
|
||||
Py_RETURN_TRUE;
|
||||
}
|
||||
else {
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
}
|
||||
|
||||
/*[clinic input]
|
||||
_asyncio.Task.get_stack
|
||||
|
||||
|
@ -2455,6 +2512,8 @@ static PyMethodDef TaskType_methods[] = {
|
|||
_ASYNCIO_TASK_SET_RESULT_METHODDEF
|
||||
_ASYNCIO_TASK_SET_EXCEPTION_METHODDEF
|
||||
_ASYNCIO_TASK_CANCEL_METHODDEF
|
||||
_ASYNCIO_TASK_CANCELLING_METHODDEF
|
||||
_ASYNCIO_TASK_UNCANCEL_METHODDEF
|
||||
_ASYNCIO_TASK_GET_STACK_METHODDEF
|
||||
_ASYNCIO_TASK_PRINT_STACK_METHODDEF
|
||||
_ASYNCIO_TASK__MAKE_CANCELLED_ERROR_METHODDEF
|
||||
|
|
|
@ -447,6 +447,53 @@ exit:
|
|||
return return_value;
|
||||
}
|
||||
|
||||
PyDoc_STRVAR(_asyncio_Task_cancelling__doc__,
|
||||
"cancelling($self, /)\n"
|
||||
"--\n"
|
||||
"\n"
|
||||
"Return True if the task is in the process of being cancelled.\n"
|
||||
"\n"
|
||||
"This is set once .cancel() is called\n"
|
||||
"and remains set until .uncancel() is called.\n"
|
||||
"\n"
|
||||
"As long as this flag is set, further .cancel() calls will be ignored,\n"
|
||||
"until .uncancel() is called to reset it.");
|
||||
|
||||
#define _ASYNCIO_TASK_CANCELLING_METHODDEF \
|
||||
{"cancelling", (PyCFunction)_asyncio_Task_cancelling, METH_NOARGS, _asyncio_Task_cancelling__doc__},
|
||||
|
||||
static PyObject *
|
||||
_asyncio_Task_cancelling_impl(TaskObj *self);
|
||||
|
||||
static PyObject *
|
||||
_asyncio_Task_cancelling(TaskObj *self, PyObject *Py_UNUSED(ignored))
|
||||
{
|
||||
return _asyncio_Task_cancelling_impl(self);
|
||||
}
|
||||
|
||||
PyDoc_STRVAR(_asyncio_Task_uncancel__doc__,
|
||||
"uncancel($self, /)\n"
|
||||
"--\n"
|
||||
"\n"
|
||||
"Reset the flag returned by cancelling().\n"
|
||||
"\n"
|
||||
"This should be used by tasks that catch CancelledError\n"
|
||||
"and wish to continue indefinitely until they are cancelled again.\n"
|
||||
"\n"
|
||||
"Returns the previous value of the flag.");
|
||||
|
||||
#define _ASYNCIO_TASK_UNCANCEL_METHODDEF \
|
||||
{"uncancel", (PyCFunction)_asyncio_Task_uncancel, METH_NOARGS, _asyncio_Task_uncancel__doc__},
|
||||
|
||||
static PyObject *
|
||||
_asyncio_Task_uncancel_impl(TaskObj *self);
|
||||
|
||||
static PyObject *
|
||||
_asyncio_Task_uncancel(TaskObj *self, PyObject *Py_UNUSED(ignored))
|
||||
{
|
||||
return _asyncio_Task_uncancel_impl(self);
|
||||
}
|
||||
|
||||
PyDoc_STRVAR(_asyncio_Task_get_stack__doc__,
|
||||
"get_stack($self, /, *, limit=None)\n"
|
||||
"--\n"
|
||||
|
@ -871,4 +918,4 @@ _asyncio__leave_task(PyObject *module, PyObject *const *args, Py_ssize_t nargs,
|
|||
exit:
|
||||
return return_value;
|
||||
}
|
||||
/*[clinic end generated code: output=0d127162ac92e0c0 input=a9049054013a1b77]*/
|
||||
/*[clinic end generated code: output=c02708a9d6a774cc input=a9049054013a1b77]*/
|
||||
|
|
Loading…
Reference in New Issue