# Mirror of https://github.com/python/cpython (1003 lines, 28 KiB, Python).
# Adapted with permission from the EdgeDB project;
# license: PSFL.
|
|
|
|
import sys
|
|
import gc
|
|
import asyncio
|
|
import contextvars
|
|
import contextlib
|
|
from asyncio import taskgroups
|
|
import unittest
|
|
import warnings
|
|
|
|
from test.test_asyncio.utils import await_without_task
|
|
|
|
def tearDownModule():
    """Restore asyncio's default event loop policy once this module finishes.

    Resetting the policy prevents the warning
    "test altered the execution environment".
    """
    # Passing None asks asyncio to go back to its default policy.
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
|
|
|
|
|
|
class MyExc(Exception):
    """Ordinary ``Exception`` subclass the tests raise as a recognisable error."""
|
|
|
|
|
|
class MyBaseExc(BaseException):
    """Direct ``BaseException`` subclass; it is deliberately *not* an ``Exception``."""
|
|
|
|
|
|
def get_error_types(eg):
|
|
return {type(exc) for exc in eg.exceptions}
|
|
|
|
|
|
def no_other_refs():
    """Return the referrers expected for an object held in the caller's locals.

    Since gh-124392, a coroutine refers to its own locals. The expected
    referrer is therefore the coroutine whose frame belongs to our caller.

    To find it, start at the current task's coroutine and follow
    ``cr_await`` until reaching the coroutine whose ``cr_frame`` is the
    caller's frame. That coroutine is returned in a one-element list.
    Must be called from code running inside an asyncio task.
    """
    caller_frame = sys._getframe(1)
    candidate = asyncio.current_task().get_coro()
    while candidate.cr_frame != caller_frame:
        candidate = candidate.cr_await
    return [candidate]
|
|
|
|
|
|
class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
|
|
|
|
async def test_taskgroup_01(self):
|
|
|
|
async def foo1():
|
|
await asyncio.sleep(0.1)
|
|
return 42
|
|
|
|
async def foo2():
|
|
await asyncio.sleep(0.2)
|
|
return 11
|
|
|
|
async with taskgroups.TaskGroup() as g:
|
|
t1 = g.create_task(foo1())
|
|
t2 = g.create_task(foo2())
|
|
|
|
self.assertEqual(t1.result(), 42)
|
|
self.assertEqual(t2.result(), 11)
|
|
|
|
async def test_taskgroup_02(self):
|
|
|
|
async def foo1():
|
|
await asyncio.sleep(0.1)
|
|
return 42
|
|
|
|
async def foo2():
|
|
await asyncio.sleep(0.2)
|
|
return 11
|
|
|
|
async with taskgroups.TaskGroup() as g:
|
|
t1 = g.create_task(foo1())
|
|
await asyncio.sleep(0.15)
|
|
t2 = g.create_task(foo2())
|
|
|
|
self.assertEqual(t1.result(), 42)
|
|
self.assertEqual(t2.result(), 11)
|
|
|
|
async def test_taskgroup_03(self):
|
|
|
|
async def foo1():
|
|
await asyncio.sleep(1)
|
|
return 42
|
|
|
|
async def foo2():
|
|
await asyncio.sleep(0.2)
|
|
return 11
|
|
|
|
async with taskgroups.TaskGroup() as g:
|
|
t1 = g.create_task(foo1())
|
|
await asyncio.sleep(0.15)
|
|
# cancel t1 explicitly, i.e. everything should continue
|
|
# working as expected.
|
|
t1.cancel()
|
|
|
|
t2 = g.create_task(foo2())
|
|
|
|
self.assertTrue(t1.cancelled())
|
|
self.assertEqual(t2.result(), 11)
|
|
|
|
async def test_taskgroup_04(self):
|
|
|
|
NUM = 0
|
|
t2_cancel = False
|
|
t2 = None
|
|
|
|
async def foo1():
|
|
await asyncio.sleep(0.1)
|
|
1 / 0
|
|
|
|
async def foo2():
|
|
nonlocal NUM, t2_cancel
|
|
try:
|
|
await asyncio.sleep(1)
|
|
except asyncio.CancelledError:
|
|
t2_cancel = True
|
|
raise
|
|
NUM += 1
|
|
|
|
async def runner():
|
|
nonlocal NUM, t2
|
|
|
|
async with taskgroups.TaskGroup() as g:
|
|
g.create_task(foo1())
|
|
t2 = g.create_task(foo2())
|
|
|
|
NUM += 10
|
|
|
|
with self.assertRaises(ExceptionGroup) as cm:
|
|
await asyncio.create_task(runner())
|
|
|
|
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
|
|
|
|
self.assertEqual(NUM, 0)
|
|
self.assertTrue(t2_cancel)
|
|
self.assertTrue(t2.cancelled())
|
|
|
|
async def test_cancel_children_on_child_error(self):
|
|
# When a child task raises an error, the rest of the children
|
|
# are cancelled and the errors are gathered into an EG.
|
|
|
|
NUM = 0
|
|
t2_cancel = False
|
|
runner_cancel = False
|
|
|
|
async def foo1():
|
|
await asyncio.sleep(0.1)
|
|
1 / 0
|
|
|
|
async def foo2():
|
|
nonlocal NUM, t2_cancel
|
|
try:
|
|
await asyncio.sleep(5)
|
|
except asyncio.CancelledError:
|
|
t2_cancel = True
|
|
raise
|
|
NUM += 1
|
|
|
|
async def runner():
|
|
nonlocal NUM, runner_cancel
|
|
|
|
async with taskgroups.TaskGroup() as g:
|
|
g.create_task(foo1())
|
|
g.create_task(foo1())
|
|
g.create_task(foo1())
|
|
g.create_task(foo2())
|
|
try:
|
|
await asyncio.sleep(10)
|
|
except asyncio.CancelledError:
|
|
runner_cancel = True
|
|
raise
|
|
|
|
NUM += 10
|
|
|
|
# The 3 foo1 sub tasks can be racy when the host is busy - if the
|
|
# cancellation happens in the middle, we'll see partial sub errors here
|
|
with self.assertRaises(ExceptionGroup) as cm:
|
|
await asyncio.create_task(runner())
|
|
|
|
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
|
|
self.assertEqual(NUM, 0)
|
|
self.assertTrue(t2_cancel)
|
|
self.assertTrue(runner_cancel)
|
|
|
|
async def test_cancellation(self):
|
|
|
|
NUM = 0
|
|
|
|
async def foo():
|
|
nonlocal NUM
|
|
try:
|
|
await asyncio.sleep(5)
|
|
except asyncio.CancelledError:
|
|
NUM += 1
|
|
raise
|
|
|
|
async def runner():
|
|
async with taskgroups.TaskGroup() as g:
|
|
for _ in range(5):
|
|
g.create_task(foo())
|
|
|
|
r = asyncio.create_task(runner())
|
|
await asyncio.sleep(0.1)
|
|
|
|
self.assertFalse(r.done())
|
|
r.cancel()
|
|
with self.assertRaises(asyncio.CancelledError) as cm:
|
|
await r
|
|
|
|
self.assertEqual(NUM, 5)
|
|
|
|
async def test_taskgroup_07(self):
|
|
|
|
NUM = 0
|
|
|
|
async def foo():
|
|
nonlocal NUM
|
|
try:
|
|
await asyncio.sleep(5)
|
|
except asyncio.CancelledError:
|
|
NUM += 1
|
|
raise
|
|
|
|
async def runner():
|
|
nonlocal NUM
|
|
async with taskgroups.TaskGroup() as g:
|
|
for _ in range(5):
|
|
g.create_task(foo())
|
|
|
|
try:
|
|
await asyncio.sleep(10)
|
|
except asyncio.CancelledError:
|
|
NUM += 10
|
|
raise
|
|
|
|
r = asyncio.create_task(runner())
|
|
await asyncio.sleep(0.1)
|
|
|
|
self.assertFalse(r.done())
|
|
r.cancel()
|
|
with self.assertRaises(asyncio.CancelledError):
|
|
await r
|
|
|
|
self.assertEqual(NUM, 15)
|
|
|
|
async def test_taskgroup_08(self):
|
|
|
|
async def foo():
|
|
try:
|
|
await asyncio.sleep(10)
|
|
finally:
|
|
1 / 0
|
|
|
|
async def runner():
|
|
async with taskgroups.TaskGroup() as g:
|
|
for _ in range(5):
|
|
g.create_task(foo())
|
|
|
|
await asyncio.sleep(10)
|
|
|
|
r = asyncio.create_task(runner())
|
|
await asyncio.sleep(0.1)
|
|
|
|
self.assertFalse(r.done())
|
|
r.cancel()
|
|
with self.assertRaises(ExceptionGroup) as cm:
|
|
await r
|
|
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
|
|
|
|
async def test_taskgroup_09(self):
|
|
|
|
t1 = t2 = None
|
|
|
|
async def foo1():
|
|
await asyncio.sleep(1)
|
|
return 42
|
|
|
|
async def foo2():
|
|
await asyncio.sleep(2)
|
|
return 11
|
|
|
|
async def runner():
|
|
nonlocal t1, t2
|
|
async with taskgroups.TaskGroup() as g:
|
|
t1 = g.create_task(foo1())
|
|
t2 = g.create_task(foo2())
|
|
await asyncio.sleep(0.1)
|
|
1 / 0
|
|
|
|
try:
|
|
await runner()
|
|
except ExceptionGroup as t:
|
|
self.assertEqual(get_error_types(t), {ZeroDivisionError})
|
|
else:
|
|
self.fail('ExceptionGroup was not raised')
|
|
|
|
self.assertTrue(t1.cancelled())
|
|
self.assertTrue(t2.cancelled())
|
|
|
|
async def test_taskgroup_10(self):
|
|
|
|
t1 = t2 = None
|
|
|
|
async def foo1():
|
|
await asyncio.sleep(1)
|
|
return 42
|
|
|
|
async def foo2():
|
|
await asyncio.sleep(2)
|
|
return 11
|
|
|
|
async def runner():
|
|
nonlocal t1, t2
|
|
async with taskgroups.TaskGroup() as g:
|
|
t1 = g.create_task(foo1())
|
|
t2 = g.create_task(foo2())
|
|
1 / 0
|
|
|
|
try:
|
|
await runner()
|
|
except ExceptionGroup as t:
|
|
self.assertEqual(get_error_types(t), {ZeroDivisionError})
|
|
else:
|
|
self.fail('ExceptionGroup was not raised')
|
|
|
|
self.assertTrue(t1.cancelled())
|
|
self.assertTrue(t2.cancelled())
|
|
|
|
async def test_taskgroup_11(self):
|
|
|
|
async def foo():
|
|
try:
|
|
await asyncio.sleep(10)
|
|
finally:
|
|
1 / 0
|
|
|
|
async def runner():
|
|
async with taskgroups.TaskGroup():
|
|
async with taskgroups.TaskGroup() as g2:
|
|
for _ in range(5):
|
|
g2.create_task(foo())
|
|
|
|
await asyncio.sleep(10)
|
|
|
|
r = asyncio.create_task(runner())
|
|
await asyncio.sleep(0.1)
|
|
|
|
self.assertFalse(r.done())
|
|
r.cancel()
|
|
with self.assertRaises(ExceptionGroup) as cm:
|
|
await r
|
|
|
|
self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
|
|
self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError})
|
|
|
|
async def test_taskgroup_12(self):
|
|
|
|
async def foo():
|
|
try:
|
|
await asyncio.sleep(10)
|
|
finally:
|
|
1 / 0
|
|
|
|
async def runner():
|
|
async with taskgroups.TaskGroup() as g1:
|
|
g1.create_task(asyncio.sleep(10))
|
|
|
|
async with taskgroups.TaskGroup() as g2:
|
|
for _ in range(5):
|
|
g2.create_task(foo())
|
|
|
|
await asyncio.sleep(10)
|
|
|
|
r = asyncio.create_task(runner())
|
|
await asyncio.sleep(0.1)
|
|
|
|
self.assertFalse(r.done())
|
|
r.cancel()
|
|
with self.assertRaises(ExceptionGroup) as cm:
|
|
await r
|
|
|
|
self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
|
|
self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError})
|
|
|
|
async def test_taskgroup_13(self):
|
|
|
|
async def crash_after(t):
|
|
await asyncio.sleep(t)
|
|
raise ValueError(t)
|
|
|
|
async def runner():
|
|
async with taskgroups.TaskGroup() as g1:
|
|
g1.create_task(crash_after(0.1))
|
|
|
|
async with taskgroups.TaskGroup() as g2:
|
|
g2.create_task(crash_after(10))
|
|
|
|
r = asyncio.create_task(runner())
|
|
with self.assertRaises(ExceptionGroup) as cm:
|
|
await r
|
|
|
|
self.assertEqual(get_error_types(cm.exception), {ValueError})
|
|
|
|
async def test_taskgroup_14(self):
|
|
|
|
async def crash_after(t):
|
|
await asyncio.sleep(t)
|
|
raise ValueError(t)
|
|
|
|
async def runner():
|
|
async with taskgroups.TaskGroup() as g1:
|
|
g1.create_task(crash_after(10))
|
|
|
|
async with taskgroups.TaskGroup() as g2:
|
|
g2.create_task(crash_after(0.1))
|
|
|
|
r = asyncio.create_task(runner())
|
|
with self.assertRaises(ExceptionGroup) as cm:
|
|
await r
|
|
|
|
self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
|
|
self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ValueError})
|
|
|
|
async def test_taskgroup_15(self):
|
|
|
|
async def crash_soon():
|
|
await asyncio.sleep(0.3)
|
|
1 / 0
|
|
|
|
async def runner():
|
|
async with taskgroups.TaskGroup() as g1:
|
|
g1.create_task(crash_soon())
|
|
try:
|
|
await asyncio.sleep(10)
|
|
except asyncio.CancelledError:
|
|
await asyncio.sleep(0.5)
|
|
raise
|
|
|
|
r = asyncio.create_task(runner())
|
|
await asyncio.sleep(0.1)
|
|
|
|
self.assertFalse(r.done())
|
|
r.cancel()
|
|
with self.assertRaises(ExceptionGroup) as cm:
|
|
await r
|
|
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
|
|
|
|
async def test_taskgroup_16(self):
|
|
|
|
async def crash_soon():
|
|
await asyncio.sleep(0.3)
|
|
1 / 0
|
|
|
|
async def nested_runner():
|
|
async with taskgroups.TaskGroup() as g1:
|
|
g1.create_task(crash_soon())
|
|
try:
|
|
await asyncio.sleep(10)
|
|
except asyncio.CancelledError:
|
|
await asyncio.sleep(0.5)
|
|
raise
|
|
|
|
async def runner():
|
|
t = asyncio.create_task(nested_runner())
|
|
await t
|
|
|
|
r = asyncio.create_task(runner())
|
|
await asyncio.sleep(0.1)
|
|
|
|
self.assertFalse(r.done())
|
|
r.cancel()
|
|
with self.assertRaises(ExceptionGroup) as cm:
|
|
await r
|
|
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
|
|
|
|
async def test_taskgroup_17(self):
|
|
NUM = 0
|
|
|
|
async def runner():
|
|
nonlocal NUM
|
|
async with taskgroups.TaskGroup():
|
|
try:
|
|
await asyncio.sleep(10)
|
|
except asyncio.CancelledError:
|
|
NUM += 10
|
|
raise
|
|
|
|
r = asyncio.create_task(runner())
|
|
await asyncio.sleep(0.1)
|
|
|
|
self.assertFalse(r.done())
|
|
r.cancel()
|
|
with self.assertRaises(asyncio.CancelledError):
|
|
await r
|
|
|
|
self.assertEqual(NUM, 10)
|
|
|
|
async def test_taskgroup_18(self):
|
|
NUM = 0
|
|
|
|
async def runner():
|
|
nonlocal NUM
|
|
async with taskgroups.TaskGroup():
|
|
try:
|
|
await asyncio.sleep(10)
|
|
except asyncio.CancelledError:
|
|
NUM += 10
|
|
# This isn't a good idea, but we have to support
|
|
# this weird case.
|
|
raise MyExc
|
|
|
|
r = asyncio.create_task(runner())
|
|
await asyncio.sleep(0.1)
|
|
|
|
self.assertFalse(r.done())
|
|
r.cancel()
|
|
|
|
try:
|
|
await r
|
|
except ExceptionGroup as t:
|
|
self.assertEqual(get_error_types(t),{MyExc})
|
|
else:
|
|
self.fail('ExceptionGroup was not raised')
|
|
|
|
self.assertEqual(NUM, 10)
|
|
|
|
async def test_taskgroup_19(self):
|
|
async def crash_soon():
|
|
await asyncio.sleep(0.1)
|
|
1 / 0
|
|
|
|
async def nested():
|
|
try:
|
|
await asyncio.sleep(10)
|
|
finally:
|
|
raise MyExc
|
|
|
|
async def runner():
|
|
async with taskgroups.TaskGroup() as g:
|
|
g.create_task(crash_soon())
|
|
await nested()
|
|
|
|
r = asyncio.create_task(runner())
|
|
try:
|
|
await r
|
|
except ExceptionGroup as t:
|
|
self.assertEqual(get_error_types(t), {MyExc, ZeroDivisionError})
|
|
else:
|
|
self.fail('TasgGroupError was not raised')
|
|
|
|
async def test_taskgroup_20(self):
|
|
async def crash_soon():
|
|
await asyncio.sleep(0.1)
|
|
1 / 0
|
|
|
|
async def nested():
|
|
try:
|
|
await asyncio.sleep(10)
|
|
finally:
|
|
raise KeyboardInterrupt
|
|
|
|
async def runner():
|
|
async with taskgroups.TaskGroup() as g:
|
|
g.create_task(crash_soon())
|
|
await nested()
|
|
|
|
with self.assertRaises(KeyboardInterrupt):
|
|
await runner()
|
|
|
|
async def test_taskgroup_20a(self):
|
|
async def crash_soon():
|
|
await asyncio.sleep(0.1)
|
|
1 / 0
|
|
|
|
async def nested():
|
|
try:
|
|
await asyncio.sleep(10)
|
|
finally:
|
|
raise MyBaseExc
|
|
|
|
async def runner():
|
|
async with taskgroups.TaskGroup() as g:
|
|
g.create_task(crash_soon())
|
|
await nested()
|
|
|
|
with self.assertRaises(BaseExceptionGroup) as cm:
|
|
await runner()
|
|
|
|
self.assertEqual(
|
|
get_error_types(cm.exception), {MyBaseExc, ZeroDivisionError}
|
|
)
|
|
|
|
async def _test_taskgroup_21(self):
|
|
# This test doesn't work as asyncio, currently, doesn't
|
|
# correctly propagate KeyboardInterrupt (or SystemExit) --
|
|
# those cause the event loop itself to crash.
|
|
# (Compare to the previous (passing) test -- that one raises
|
|
# a plain exception but raises KeyboardInterrupt in nested();
|
|
# this test does it the other way around.)
|
|
|
|
async def crash_soon():
|
|
await asyncio.sleep(0.1)
|
|
raise KeyboardInterrupt
|
|
|
|
async def nested():
|
|
try:
|
|
await asyncio.sleep(10)
|
|
finally:
|
|
raise TypeError
|
|
|
|
async def runner():
|
|
async with taskgroups.TaskGroup() as g:
|
|
g.create_task(crash_soon())
|
|
await nested()
|
|
|
|
with self.assertRaises(KeyboardInterrupt):
|
|
await runner()
|
|
|
|
async def test_taskgroup_21a(self):
|
|
|
|
async def crash_soon():
|
|
await asyncio.sleep(0.1)
|
|
raise MyBaseExc
|
|
|
|
async def nested():
|
|
try:
|
|
await asyncio.sleep(10)
|
|
finally:
|
|
raise TypeError
|
|
|
|
async def runner():
|
|
async with taskgroups.TaskGroup() as g:
|
|
g.create_task(crash_soon())
|
|
await nested()
|
|
|
|
with self.assertRaises(BaseExceptionGroup) as cm:
|
|
await runner()
|
|
|
|
self.assertEqual(get_error_types(cm.exception), {MyBaseExc, TypeError})
|
|
|
|
async def test_taskgroup_22(self):
|
|
|
|
async def foo1():
|
|
await asyncio.sleep(1)
|
|
return 42
|
|
|
|
async def foo2():
|
|
await asyncio.sleep(2)
|
|
return 11
|
|
|
|
async def runner():
|
|
async with taskgroups.TaskGroup() as g:
|
|
g.create_task(foo1())
|
|
g.create_task(foo2())
|
|
|
|
r = asyncio.create_task(runner())
|
|
await asyncio.sleep(0.05)
|
|
r.cancel()
|
|
|
|
with self.assertRaises(asyncio.CancelledError):
|
|
await r
|
|
|
|
async def test_taskgroup_23(self):
|
|
|
|
async def do_job(delay):
|
|
await asyncio.sleep(delay)
|
|
|
|
async with taskgroups.TaskGroup() as g:
|
|
for count in range(10):
|
|
await asyncio.sleep(0.1)
|
|
g.create_task(do_job(0.3))
|
|
if count == 5:
|
|
self.assertLess(len(g._tasks), 5)
|
|
await asyncio.sleep(1.35)
|
|
self.assertEqual(len(g._tasks), 0)
|
|
|
|
async def test_taskgroup_24(self):
|
|
|
|
async def root(g):
|
|
await asyncio.sleep(0.1)
|
|
g.create_task(coro1(0.1))
|
|
g.create_task(coro1(0.2))
|
|
|
|
async def coro1(delay):
|
|
await asyncio.sleep(delay)
|
|
|
|
async def runner():
|
|
async with taskgroups.TaskGroup() as g:
|
|
g.create_task(root(g))
|
|
|
|
await runner()
|
|
|
|
async def test_taskgroup_25(self):
|
|
nhydras = 0
|
|
|
|
async def hydra(g):
|
|
nonlocal nhydras
|
|
nhydras += 1
|
|
await asyncio.sleep(0.01)
|
|
g.create_task(hydra(g))
|
|
g.create_task(hydra(g))
|
|
|
|
async def hercules():
|
|
while nhydras < 10:
|
|
await asyncio.sleep(0.015)
|
|
1 / 0
|
|
|
|
async def runner():
|
|
async with taskgroups.TaskGroup() as g:
|
|
g.create_task(hydra(g))
|
|
g.create_task(hercules())
|
|
|
|
with self.assertRaises(ExceptionGroup) as cm:
|
|
await runner()
|
|
|
|
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
|
|
self.assertGreaterEqual(nhydras, 10)
|
|
|
|
async def test_taskgroup_task_name(self):
|
|
async def coro():
|
|
await asyncio.sleep(0)
|
|
async with taskgroups.TaskGroup() as g:
|
|
t = g.create_task(coro(), name="yolo")
|
|
self.assertEqual(t.get_name(), "yolo")
|
|
|
|
async def test_taskgroup_task_context(self):
|
|
cvar = contextvars.ContextVar('cvar')
|
|
|
|
async def coro(val):
|
|
await asyncio.sleep(0)
|
|
cvar.set(val)
|
|
|
|
async with taskgroups.TaskGroup() as g:
|
|
ctx = contextvars.copy_context()
|
|
self.assertIsNone(ctx.get(cvar))
|
|
t1 = g.create_task(coro(1), context=ctx)
|
|
await t1
|
|
self.assertEqual(1, ctx.get(cvar))
|
|
t2 = g.create_task(coro(2), context=ctx)
|
|
await t2
|
|
self.assertEqual(2, ctx.get(cvar))
|
|
|
|
async def test_taskgroup_no_create_task_after_failure(self):
|
|
async def coro1():
|
|
await asyncio.sleep(0.001)
|
|
1 / 0
|
|
async def coro2(g):
|
|
try:
|
|
await asyncio.sleep(1)
|
|
except asyncio.CancelledError:
|
|
with self.assertRaises(RuntimeError):
|
|
g.create_task(coro1())
|
|
|
|
with self.assertRaises(ExceptionGroup) as cm:
|
|
async with taskgroups.TaskGroup() as g:
|
|
g.create_task(coro1())
|
|
g.create_task(coro2(g))
|
|
|
|
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
|
|
|
|
async def test_taskgroup_context_manager_exit_raises(self):
|
|
# See https://github.com/python/cpython/issues/95289
|
|
class CustomException(Exception):
|
|
pass
|
|
|
|
async def raise_exc():
|
|
raise CustomException
|
|
|
|
@contextlib.asynccontextmanager
|
|
async def database():
|
|
try:
|
|
yield
|
|
finally:
|
|
raise CustomException
|
|
|
|
async def main():
|
|
task = asyncio.current_task()
|
|
try:
|
|
async with taskgroups.TaskGroup() as tg:
|
|
async with database():
|
|
tg.create_task(raise_exc())
|
|
await asyncio.sleep(1)
|
|
except* CustomException as err:
|
|
self.assertEqual(task.cancelling(), 0)
|
|
self.assertEqual(len(err.exceptions), 2)
|
|
|
|
else:
|
|
self.fail('CustomException not raised')
|
|
|
|
await asyncio.create_task(main())
|
|
|
|
async def test_taskgroup_already_entered(self):
|
|
tg = taskgroups.TaskGroup()
|
|
async with tg:
|
|
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
|
|
async with tg:
|
|
pass
|
|
|
|
async def test_taskgroup_double_enter(self):
|
|
tg = taskgroups.TaskGroup()
|
|
async with tg:
|
|
pass
|
|
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
|
|
async with tg:
|
|
pass
|
|
|
|
async def test_taskgroup_finished(self):
|
|
async def create_task_after_tg_finish():
|
|
tg = taskgroups.TaskGroup()
|
|
async with tg:
|
|
pass
|
|
coro = asyncio.sleep(0)
|
|
with self.assertRaisesRegex(RuntimeError, "is finished"):
|
|
tg.create_task(coro)
|
|
|
|
# Make sure the coroutine was closed when submitted to the inactive tg
|
|
# (if not closed, a RuntimeWarning should have been raised)
|
|
with warnings.catch_warnings(record=True) as w:
|
|
await create_task_after_tg_finish()
|
|
self.assertEqual(len(w), 0)
|
|
|
|
async def test_taskgroup_not_entered(self):
|
|
tg = taskgroups.TaskGroup()
|
|
coro = asyncio.sleep(0)
|
|
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
|
|
tg.create_task(coro)
|
|
|
|
async def test_taskgroup_without_parent_task(self):
|
|
tg = taskgroups.TaskGroup()
|
|
with self.assertRaisesRegex(RuntimeError, "parent task"):
|
|
await await_without_task(tg.__aenter__())
|
|
coro = asyncio.sleep(0)
|
|
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
|
|
tg.create_task(coro)
|
|
|
|
def test_coro_closed_when_tg_closed(self):
|
|
async def run_coro_after_tg_closes():
|
|
async with taskgroups.TaskGroup() as tg:
|
|
pass
|
|
coro = asyncio.sleep(0)
|
|
with self.assertRaisesRegex(RuntimeError, "is finished"):
|
|
tg.create_task(coro)
|
|
loop = asyncio.get_event_loop()
|
|
loop.run_until_complete(run_coro_after_tg_closes())
|
|
|
|
async def test_cancelling_level_preserved(self):
|
|
async def raise_after(t, e):
|
|
await asyncio.sleep(t)
|
|
raise e()
|
|
|
|
try:
|
|
async with asyncio.TaskGroup() as tg:
|
|
tg.create_task(raise_after(0.0, RuntimeError))
|
|
except* RuntimeError:
|
|
pass
|
|
self.assertEqual(asyncio.current_task().cancelling(), 0)
|
|
|
|
async def test_nested_groups_both_cancelled(self):
|
|
async def raise_after(t, e):
|
|
await asyncio.sleep(t)
|
|
raise e()
|
|
|
|
try:
|
|
async with asyncio.TaskGroup() as outer_tg:
|
|
try:
|
|
async with asyncio.TaskGroup() as inner_tg:
|
|
inner_tg.create_task(raise_after(0, RuntimeError))
|
|
outer_tg.create_task(raise_after(0, ValueError))
|
|
except* RuntimeError:
|
|
pass
|
|
else:
|
|
self.fail("RuntimeError not raised")
|
|
self.assertEqual(asyncio.current_task().cancelling(), 1)
|
|
except* ValueError:
|
|
pass
|
|
else:
|
|
self.fail("ValueError not raised")
|
|
self.assertEqual(asyncio.current_task().cancelling(), 0)
|
|
|
|
async def test_error_and_cancel(self):
|
|
event = asyncio.Event()
|
|
|
|
async def raise_error():
|
|
event.set()
|
|
await asyncio.sleep(0)
|
|
raise RuntimeError()
|
|
|
|
async def inner():
|
|
try:
|
|
async with taskgroups.TaskGroup() as tg:
|
|
tg.create_task(raise_error())
|
|
await asyncio.sleep(1)
|
|
self.fail("Sleep in group should have been cancelled")
|
|
except* RuntimeError:
|
|
self.assertEqual(asyncio.current_task().cancelling(), 1)
|
|
self.assertEqual(asyncio.current_task().cancelling(), 1)
|
|
await asyncio.sleep(1)
|
|
self.fail("Sleep after group should have been cancelled")
|
|
|
|
async def outer():
|
|
t = asyncio.create_task(inner())
|
|
await event.wait()
|
|
self.assertEqual(t.cancelling(), 0)
|
|
t.cancel()
|
|
self.assertEqual(t.cancelling(), 1)
|
|
with self.assertRaises(asyncio.CancelledError):
|
|
await t
|
|
self.assertTrue(t.cancelled())
|
|
|
|
await outer()
|
|
|
|
async def test_exception_refcycles_direct(self):
|
|
"""Test that TaskGroup doesn't keep a reference to the raised ExceptionGroup"""
|
|
tg = asyncio.TaskGroup()
|
|
exc = None
|
|
|
|
class _Done(Exception):
|
|
pass
|
|
|
|
try:
|
|
async with tg:
|
|
raise _Done
|
|
except ExceptionGroup as e:
|
|
exc = e
|
|
|
|
self.assertIsNotNone(exc)
|
|
self.assertListEqual(gc.get_referrers(exc), no_other_refs())
|
|
|
|
|
|
async def test_exception_refcycles_errors(self):
|
|
"""Test that TaskGroup deletes self._errors, and __aexit__ args"""
|
|
tg = asyncio.TaskGroup()
|
|
exc = None
|
|
|
|
class _Done(Exception):
|
|
pass
|
|
|
|
try:
|
|
async with tg:
|
|
raise _Done
|
|
except* _Done as excs:
|
|
exc = excs.exceptions[0]
|
|
|
|
self.assertIsInstance(exc, _Done)
|
|
self.assertListEqual(gc.get_referrers(exc), no_other_refs())
|
|
|
|
|
|
async def test_exception_refcycles_parent_task(self):
|
|
"""Test that TaskGroup deletes self._parent_task"""
|
|
tg = asyncio.TaskGroup()
|
|
exc = None
|
|
|
|
class _Done(Exception):
|
|
pass
|
|
|
|
async def coro_fn():
|
|
async with tg:
|
|
raise _Done
|
|
|
|
try:
|
|
async with asyncio.TaskGroup() as tg2:
|
|
tg2.create_task(coro_fn())
|
|
except* _Done as excs:
|
|
exc = excs.exceptions[0].exceptions[0]
|
|
|
|
self.assertIsInstance(exc, _Done)
|
|
self.assertListEqual(gc.get_referrers(exc), no_other_refs())
|
|
|
|
async def test_exception_refcycles_propagate_cancellation_error(self):
|
|
"""Test that TaskGroup deletes propagate_cancellation_error"""
|
|
tg = asyncio.TaskGroup()
|
|
exc = None
|
|
|
|
try:
|
|
async with asyncio.timeout(-1):
|
|
async with tg:
|
|
await asyncio.sleep(0)
|
|
except TimeoutError as e:
|
|
exc = e.__cause__
|
|
|
|
self.assertIsInstance(exc, asyncio.CancelledError)
|
|
self.assertListEqual(gc.get_referrers(exc), no_other_refs())
|
|
|
|
async def test_exception_refcycles_base_error(self):
|
|
"""Test that TaskGroup deletes self._base_error"""
|
|
class MyKeyboardInterrupt(KeyboardInterrupt):
|
|
pass
|
|
|
|
tg = asyncio.TaskGroup()
|
|
exc = None
|
|
|
|
try:
|
|
async with tg:
|
|
raise MyKeyboardInterrupt
|
|
except MyKeyboardInterrupt as e:
|
|
exc = e
|
|
|
|
self.assertIsNotNone(exc)
|
|
self.assertListEqual(gc.get_referrers(exc), no_other_refs())
|
|
|
|
|
|
if __name__ == "__main__":
    # Running the module as a script executes the whole suite via unittest.
    unittest.main()
|