import asyncio from contextlib import ( asynccontextmanager, AbstractAsyncContextManager, AsyncExitStack, nullcontext, aclosing, contextmanager) import functools from test import support import unittest import traceback from test.test_contextlib import TestBaseExitStack support.requires_working_socket(module=True) def _async_test(func): """Decorator to turn an async function into a test case.""" @functools.wraps(func) def wrapper(*args, **kwargs): coro = func(*args, **kwargs) asyncio.run(coro) return wrapper def tearDownModule(): asyncio.set_event_loop_policy(None) class TestAbstractAsyncContextManager(unittest.TestCase): @_async_test async def test_enter(self): class DefaultEnter(AbstractAsyncContextManager): async def __aexit__(self, *args): await super().__aexit__(*args) manager = DefaultEnter() self.assertIs(await manager.__aenter__(), manager) async with manager as context: self.assertIs(manager, context) @_async_test async def test_slots(self): class DefaultAsyncContextManager(AbstractAsyncContextManager): __slots__ = () async def __aexit__(self, *args): await super().__aexit__(*args) with self.assertRaises(AttributeError): manager = DefaultAsyncContextManager() manager.var = 42 @_async_test async def test_async_gen_propagates_generator_exit(self): # A regression test for https://bugs.python.org/issue33786. @asynccontextmanager async def ctx(): yield async def gen(): async with ctx(): yield 11 ret = [] exc = ValueError(22) with self.assertRaises(ValueError): async with ctx(): async for val in gen(): ret.append(val) raise exc self.assertEqual(ret, [11]) def test_exit_is_abstract(self): class MissingAexit(AbstractAsyncContextManager): pass with self.assertRaises(TypeError): MissingAexit() def test_structural_subclassing(self): class ManagerFromScratch: async def __aenter__(self): return self async def __aexit__(self, exc_type, exc_value, traceback): return None self.assertTrue(issubclass(ManagerFromScratch, AbstractAsyncContextManager)) class DefaultEnter(AbstractAsyncContextManager): async def __aexit__(self, *args): await super().__aexit__(*args) self.assertTrue(issubclass(DefaultEnter, AbstractAsyncContextManager)) class NoneAenter(ManagerFromScratch): __aenter__ = None self.assertFalse(issubclass(NoneAenter, AbstractAsyncContextManager)) class NoneAexit(ManagerFromScratch): __aexit__ = None self.assertFalse(issubclass(NoneAexit, AbstractAsyncContextManager)) class AsyncContextManagerTestCase(unittest.TestCase): @_async_test async def test_contextmanager_plain(self): state = [] @asynccontextmanager async def woohoo(): state.append(1) yield 42 state.append(999) async with woohoo() as x: self.assertEqual(state, [1]) self.assertEqual(x, 42) state.append(x) self.assertEqual(state, [1, 42, 999]) @_async_test async def test_contextmanager_finally(self): state = [] @asynccontextmanager async def woohoo(): state.append(1) try: yield 42 finally: state.append(999) with self.assertRaises(ZeroDivisionError): async with woohoo() as x: self.assertEqual(state, [1]) self.assertEqual(x, 42) state.append(x) raise ZeroDivisionError() self.assertEqual(state, [1, 42, 999]) @_async_test async def test_contextmanager_traceback(self): @asynccontextmanager async def f(): yield try: async with f(): 1/0 except ZeroDivisionError as e: frames = traceback.extract_tb(e.__traceback__) self.assertEqual(len(frames), 1) self.assertEqual(frames[0].name, 'test_contextmanager_traceback') self.assertEqual(frames[0].line, '1/0') # Repeat with RuntimeError (which goes through a different code path) class RuntimeErrorSubclass(RuntimeError): pass try: async with f(): raise RuntimeErrorSubclass(42) except RuntimeErrorSubclass as e: frames = traceback.extract_tb(e.__traceback__) self.assertEqual(len(frames), 1) self.assertEqual(frames[0].name, 'test_contextmanager_traceback') self.assertEqual(frames[0].line, 'raise RuntimeErrorSubclass(42)') class StopIterationSubclass(StopIteration): pass class StopAsyncIterationSubclass(StopAsyncIteration): pass for stop_exc in ( StopIteration('spam'), StopAsyncIteration('ham'), StopIterationSubclass('spam'), StopAsyncIterationSubclass('spam') ): with self.subTest(type=type(stop_exc)): try: async with f(): raise stop_exc except type(stop_exc) as e: self.assertIs(e, stop_exc) frames = traceback.extract_tb(e.__traceback__) else: self.fail(f'{stop_exc} was suppressed') self.assertEqual(len(frames), 1) self.assertEqual(frames[0].name, 'test_contextmanager_traceback') self.assertEqual(frames[0].line, 'raise stop_exc') @_async_test async def test_contextmanager_no_reraise(self): @asynccontextmanager async def whee(): yield ctx = whee() await ctx.__aenter__() # Calling __aexit__ should not result in an exception self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None)) @_async_test async def test_contextmanager_trap_yield_after_throw(self): @asynccontextmanager async def whoo(): try: yield except: yield ctx = whoo() await ctx.__aenter__() with self.assertRaises(RuntimeError): await ctx.__aexit__(TypeError, TypeError('foo'), None) @_async_test async def test_contextmanager_trap_no_yield(self): @asynccontextmanager async def whoo(): if False: yield ctx = whoo() with self.assertRaises(RuntimeError): await ctx.__aenter__() @_async_test async def test_contextmanager_trap_second_yield(self): @asynccontextmanager async def whoo(): yield yield ctx = whoo() await ctx.__aenter__() with self.assertRaises(RuntimeError): await ctx.__aexit__(None, None, None) @_async_test async def test_contextmanager_non_normalised(self): @asynccontextmanager async def whoo(): try: yield except RuntimeError: raise SyntaxError ctx = whoo() await ctx.__aenter__() with self.assertRaises(SyntaxError): await ctx.__aexit__(RuntimeError, None, None) @_async_test async def test_contextmanager_except(self): state = [] @asynccontextmanager async def woohoo(): state.append(1) try: yield 42 except ZeroDivisionError as e: state.append(e.args[0]) self.assertEqual(state, [1, 42, 999]) async with woohoo() as x: self.assertEqual(state, [1]) self.assertEqual(x, 42) state.append(x) raise ZeroDivisionError(999) self.assertEqual(state, [1, 42, 999]) @_async_test async def test_contextmanager_except_stopiter(self): @asynccontextmanager async def woohoo(): yield class StopIterationSubclass(StopIteration): pass class StopAsyncIterationSubclass(StopAsyncIteration): pass for stop_exc in ( StopIteration('spam'), StopAsyncIteration('ham'), StopIterationSubclass('spam'), StopAsyncIterationSubclass('spam') ): with self.subTest(type=type(stop_exc)): try: async with woohoo(): raise stop_exc except Exception as ex: self.assertIs(ex, stop_exc) else: self.fail(f'{stop_exc} was suppressed') @_async_test async def test_contextmanager_wrap_runtimeerror(self): @asynccontextmanager async def woohoo(): try: yield except Exception as exc: raise RuntimeError(f'caught {exc}') from exc with self.assertRaises(RuntimeError): async with woohoo(): 1 / 0 # If the context manager wrapped StopAsyncIteration in a RuntimeError, # we also unwrap it, because we can't tell whether the wrapping was # done by the generator machinery or by the generator itself. with self.assertRaises(StopAsyncIteration): async with woohoo(): raise StopAsyncIteration def _create_contextmanager_attribs(self): def attribs(**kw): def decorate(func): for k,v in kw.items(): setattr(func,k,v) return func return decorate @asynccontextmanager @attribs(foo='bar') async def baz(spam): """Whee!""" yield return baz def test_contextmanager_attribs(self): baz = self._create_contextmanager_attribs() self.assertEqual(baz.__name__,'baz') self.assertEqual(baz.foo, 'bar') @support.requires_docstrings def test_contextmanager_doc_attrib(self): baz = self._create_contextmanager_attribs() self.assertEqual(baz.__doc__, "Whee!") @support.requires_docstrings @_async_test async def test_instance_docstring_given_cm_docstring(self): baz = self._create_contextmanager_attribs()(None) self.assertEqual(baz.__doc__, "Whee!") async with baz: pass # suppress warning @_async_test async def test_keywords(self): # Ensure no keyword arguments are inhibited @asynccontextmanager async def woohoo(self, func, args, kwds): yield (self, func, args, kwds) async with woohoo(self=11, func=22, args=33, kwds=44) as target: self.assertEqual(target, (11, 22, 33, 44)) @_async_test async def test_recursive(self): depth = 0 ncols = 0 @asynccontextmanager async def woohoo(): nonlocal ncols ncols += 1 nonlocal depth before = depth depth += 1 yield depth -= 1 self.assertEqual(depth, before) @woohoo() async def recursive(): if depth < 10: await recursive() await recursive() self.assertEqual(ncols, 10) self.assertEqual(depth, 0) @_async_test async def test_decorator(self): entered = False @asynccontextmanager async def context(): nonlocal entered entered = True yield entered = False @context() async def test(): self.assertTrue(entered) self.assertFalse(entered) await test() self.assertFalse(entered) @_async_test async def test_decorator_with_exception(self): entered = False @asynccontextmanager async def context(): nonlocal entered try: entered = True yield finally: entered = False @context() async def test(): self.assertTrue(entered) raise NameError('foo') self.assertFalse(entered) with self.assertRaisesRegex(NameError, 'foo'): await test() self.assertFalse(entered) @_async_test async def test_decorating_method(self): @asynccontextmanager async def context(): yield class Test(object): @context() async def method(self, a, b, c=None): self.a = a self.b = b self.c = c # these tests are for argument passing when used as a decorator test = Test() await test.method(1, 2) self.assertEqual(test.a, 1) self.assertEqual(test.b, 2) self.assertEqual(test.c, None) test = Test() await test.method('a', 'b', 'c') self.assertEqual(test.a, 'a') self.assertEqual(test.b, 'b') self.assertEqual(test.c, 'c') test = Test() await test.method(a=1, b=2) self.assertEqual(test.a, 1) self.assertEqual(test.b, 2) class AclosingTestCase(unittest.TestCase): @support.requires_docstrings def test_instance_docs(self): cm_docstring = aclosing.__doc__ obj = aclosing(None) self.assertEqual(obj.__doc__, cm_docstring) @_async_test async def test_aclosing(self): state = [] class C: async def aclose(self): state.append(1) x = C() self.assertEqual(state, []) async with aclosing(x) as y: self.assertEqual(x, y) self.assertEqual(state, [1]) @_async_test async def test_aclosing_error(self): state = [] class C: async def aclose(self): state.append(1) x = C() self.assertEqual(state, []) with self.assertRaises(ZeroDivisionError): async with aclosing(x) as y: self.assertEqual(x, y) 1 / 0 self.assertEqual(state, [1]) @_async_test async def test_aclosing_bpo41229(self): state = [] @contextmanager def sync_resource(): try: yield finally: state.append(1) async def agenfunc(): with sync_resource(): yield -1 yield -2 x = agenfunc() self.assertEqual(state, []) with self.assertRaises(ZeroDivisionError): async with aclosing(x) as y: self.assertEqual(x, y) self.assertEqual(-1, await x.__anext__()) 1 / 0 self.assertEqual(state, [1]) class TestAsyncExitStack(TestBaseExitStack, unittest.TestCase): class SyncAsyncExitStack(AsyncExitStack): @staticmethod def run_coroutine(coro): loop = asyncio.get_event_loop_policy().get_event_loop() t = loop.create_task(coro) t.add_done_callback(lambda f: loop.stop()) loop.run_forever() exc = t.exception() if not exc: return t.result() else: context = exc.__context__ try: raise exc except: exc.__context__ = context raise exc def close(self): return self.run_coroutine(self.aclose()) def __enter__(self): return self.run_coroutine(self.__aenter__()) def __exit__(self, *exc_details): return self.run_coroutine(self.__aexit__(*exc_details)) exit_stack = SyncAsyncExitStack callback_error_internal_frames = [ ('__exit__', 'return self.run_coroutine(self.__aexit__(*exc_details))'), ('run_coroutine', 'raise exc'), ('run_coroutine', 'raise exc'), ('__aexit__', 'raise exc'), ('__aexit__', 'cb_suppress = cb(*exc_details)'), ] def setUp(self): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) self.addCleanup(self.loop.close) self.addCleanup(asyncio.set_event_loop_policy, None) @_async_test async def test_async_callback(self): expected = [ ((), {}), ((1,), {}), ((1,2), {}), ((), dict(example=1)), ((1,), dict(example=1)), ((1,2), dict(example=1)), ] result = [] async def _exit(*args, **kwds): """Test metadata propagation""" result.append((args, kwds)) async with AsyncExitStack() as stack: for args, kwds in reversed(expected): if args and kwds: f = stack.push_async_callback(_exit, *args, **kwds) elif args: f = stack.push_async_callback(_exit, *args) elif kwds: f = stack.push_async_callback(_exit, **kwds) else: f = stack.push_async_callback(_exit) self.assertIs(f, _exit) for wrapper in stack._exit_callbacks: self.assertIs(wrapper[1].__wrapped__, _exit) self.assertNotEqual(wrapper[1].__name__, _exit.__name__) self.assertIsNone(wrapper[1].__doc__, _exit.__doc__) self.assertEqual(result, expected) result = [] async with AsyncExitStack() as stack: with self.assertRaises(TypeError): stack.push_async_callback(arg=1) with self.assertRaises(TypeError): self.exit_stack.push_async_callback(arg=2) with self.assertRaises(TypeError): stack.push_async_callback(callback=_exit, arg=3) self.assertEqual(result, []) @_async_test async def test_async_push(self): exc_raised = ZeroDivisionError async def _expect_exc(exc_type, exc, exc_tb): self.assertIs(exc_type, exc_raised) async def _suppress_exc(*exc_details): return True async def _expect_ok(exc_type, exc, exc_tb): self.assertIsNone(exc_type) self.assertIsNone(exc) self.assertIsNone(exc_tb) class ExitCM(object): def __init__(self, check_exc): self.check_exc = check_exc async def __aenter__(self): self.fail("Should not be called!") async def __aexit__(self, *exc_details): await self.check_exc(*exc_details) async with self.exit_stack() as stack: stack.push_async_exit(_expect_ok) self.assertIs(stack._exit_callbacks[-1][1], _expect_ok) cm = ExitCM(_expect_ok) stack.push_async_exit(cm) self.assertIs(stack._exit_callbacks[-1][1].__self__, cm) stack.push_async_exit(_suppress_exc) self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc) cm = ExitCM(_expect_exc) stack.push_async_exit(cm) self.assertIs(stack._exit_callbacks[-1][1].__self__, cm) stack.push_async_exit(_expect_exc) self.assertIs(stack._exit_callbacks[-1][1], _expect_exc) stack.push_async_exit(_expect_exc) self.assertIs(stack._exit_callbacks[-1][1], _expect_exc) 1/0 @_async_test async def test_enter_async_context(self): class TestCM(object): async def __aenter__(self): result.append(1) async def __aexit__(self, *exc_details): result.append(3) result = [] cm = TestCM() async with AsyncExitStack() as stack: @stack.push_async_callback # Registered first => cleaned up last async def _exit(): result.append(4) self.assertIsNotNone(_exit) await stack.enter_async_context(cm) self.assertIs(stack._exit_callbacks[-1][1].__self__, cm) result.append(2) self.assertEqual(result, [1, 2, 3, 4]) @_async_test async def test_enter_async_context_errors(self): class LacksEnterAndExit: pass class LacksEnter: async def __aexit__(self, *exc_info): pass class LacksExit: async def __aenter__(self): pass async with self.exit_stack() as stack: with self.assertRaisesRegex(TypeError, 'asynchronous context manager'): await stack.enter_async_context(LacksEnterAndExit()) with self.assertRaisesRegex(TypeError, 'asynchronous context manager'): await stack.enter_async_context(LacksEnter()) with self.assertRaisesRegex(TypeError, 'asynchronous context manager'): await stack.enter_async_context(LacksExit()) self.assertFalse(stack._exit_callbacks) @_async_test async def test_async_exit_exception_chaining(self): # Ensure exception chaining matches the reference behaviour async def raise_exc(exc): raise exc saved_details = None async def suppress_exc(*exc_details): nonlocal saved_details saved_details = exc_details return True try: async with self.exit_stack() as stack: stack.push_async_callback(raise_exc, IndexError) stack.push_async_callback(raise_exc, KeyError) stack.push_async_callback(raise_exc, AttributeError) stack.push_async_exit(suppress_exc) stack.push_async_callback(raise_exc, ValueError) 1 / 0 except IndexError as exc: self.assertIsInstance(exc.__context__, KeyError) self.assertIsInstance(exc.__context__.__context__, AttributeError) # Inner exceptions were suppressed self.assertIsNone(exc.__context__.__context__.__context__) else: self.fail("Expected IndexError, but no exception was raised") # Check the inner exceptions inner_exc = saved_details[1] self.assertIsInstance(inner_exc, ValueError) self.assertIsInstance(inner_exc.__context__, ZeroDivisionError) @_async_test async def test_async_exit_exception_explicit_none_context(self): # Ensure AsyncExitStack chaining matches actual nested `with` statements # regarding explicit __context__ = None. class MyException(Exception): pass @asynccontextmanager async def my_cm(): try: yield except BaseException: exc = MyException() try: raise exc finally: exc.__context__ = None @asynccontextmanager async def my_cm_with_exit_stack(): async with self.exit_stack() as stack: await stack.enter_async_context(my_cm()) yield stack for cm in (my_cm, my_cm_with_exit_stack): with self.subTest(): try: async with cm(): raise IndexError() except MyException as exc: self.assertIsNone(exc.__context__) else: self.fail("Expected IndexError, but no exception was raised") @_async_test async def test_instance_bypass_async(self): class Example(object): pass cm = Example() cm.__aenter__ = object() cm.__aexit__ = object() stack = self.exit_stack() with self.assertRaisesRegex(TypeError, 'asynchronous context manager'): await stack.enter_async_context(cm) stack.push_async_exit(cm) self.assertIs(stack._exit_callbacks[-1][1], cm) class TestAsyncNullcontext(unittest.TestCase): @_async_test async def test_async_nullcontext(self): class C: pass c = C() async with nullcontext(c) as c_in: self.assertIs(c_in, c) if __name__ == '__main__': unittest.main()