bpo-29679: Implement @contextlib.asynccontextmanager (#360)
This commit is contained in:
parent
9dc2b3809f
commit
2e624690bd
|
@ -80,6 +80,36 @@ Functions and classes provided:
|
|||
Use of :class:`ContextDecorator`.
|
||||
|
||||
|
||||
.. decorator:: asynccontextmanager
|
||||
|
||||
Similar to :func:`~contextlib.contextmanager`, but creates an
|
||||
:ref:`asynchronous context manager <async-context-managers>`.
|
||||
|
||||
This function is a :term:`decorator` that can be used to define a factory
|
||||
function for :keyword:`async with` statement asynchronous context managers,
|
||||
without needing to create a class or separate :meth:`__aenter__` and
|
||||
:meth:`__aexit__` methods. It must be applied to an :term:`asynchronous
|
||||
generator` function.
|
||||
|
||||
A simple example::
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_connection():
|
||||
conn = await acquire_db_connection()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
await release_db_connection(conn)
|
||||
|
||||
async def get_all_users():
|
||||
async with get_connection() as conn:
|
||||
return conn.query('SELECT ...')
|
||||
|
||||
.. versionadded:: 3.7
|
||||
|
||||
|
||||
.. function:: closing(thing)
|
||||
|
||||
Return a context manager that closes *thing* upon completion of the block. This
|
||||
|
|
|
@ -2575,6 +2575,8 @@ An example of an asynchronous iterable object::
|
|||
result in a :exc:`RuntimeError`.
|
||||
|
||||
|
||||
.. _async-context-managers:
|
||||
|
||||
Asynchronous Context Managers
|
||||
-----------------------------
|
||||
|
||||
|
|
|
@ -95,6 +95,12 @@ New Modules
|
|||
Improved Modules
|
||||
================
|
||||
|
||||
contextlib
|
||||
----------
|
||||
|
||||
:func:`contextlib.asynccontextmanager` has been added. (Contributed by
|
||||
Jelle Zijlstra in :issue:`29679`.)
|
||||
|
||||
distutils
|
||||
---------
|
||||
|
||||
|
|
|
@ -4,9 +4,9 @@ import sys
|
|||
from collections import deque
|
||||
from functools import wraps
|
||||
|
||||
__all__ = ["contextmanager", "closing", "AbstractContextManager",
|
||||
"ContextDecorator", "ExitStack", "redirect_stdout",
|
||||
"redirect_stderr", "suppress"]
|
||||
__all__ = ["asynccontextmanager", "contextmanager", "closing",
|
||||
"AbstractContextManager", "ContextDecorator", "ExitStack",
|
||||
"redirect_stdout", "redirect_stderr", "suppress"]
|
||||
|
||||
|
||||
class AbstractContextManager(abc.ABC):
|
||||
|
@ -54,8 +54,8 @@ class ContextDecorator(object):
|
|||
return inner
|
||||
|
||||
|
||||
class _GeneratorContextManager(ContextDecorator, AbstractContextManager):
|
||||
"""Helper for @contextmanager decorator."""
|
||||
class _GeneratorContextManagerBase:
|
||||
"""Shared functionality for @contextmanager and @asynccontextmanager."""
|
||||
|
||||
def __init__(self, func, args, kwds):
|
||||
self.gen = func(*args, **kwds)
|
||||
|
@ -71,6 +71,12 @@ class _GeneratorContextManager(ContextDecorator, AbstractContextManager):
|
|||
# for the class instead.
|
||||
# See http://bugs.python.org/issue19404 for more details.
|
||||
|
||||
|
||||
class _GeneratorContextManager(_GeneratorContextManagerBase,
|
||||
AbstractContextManager,
|
||||
ContextDecorator):
|
||||
"""Helper for @contextmanager decorator."""
|
||||
|
||||
def _recreate_cm(self):
|
||||
# _GCM instances are one-shot context managers, so the
|
||||
# CM must be recreated each time a decorated function is
|
||||
|
@ -121,12 +127,61 @@ class _GeneratorContextManager(ContextDecorator, AbstractContextManager):
|
|||
# fixes the impedance mismatch between the throw() protocol
|
||||
# and the __exit__() protocol.
|
||||
#
|
||||
# This cannot use 'except BaseException as exc' (as in the
|
||||
# async implementation) to maintain compatibility with
|
||||
# Python 2, where old-style class exceptions are not caught
|
||||
# by 'except BaseException'.
|
||||
if sys.exc_info()[1] is value:
|
||||
return False
|
||||
raise
|
||||
raise RuntimeError("generator didn't stop after throw()")
|
||||
|
||||
|
||||
class _AsyncGeneratorContextManager(_GeneratorContextManagerBase):
|
||||
"""Helper for @asynccontextmanager."""
|
||||
|
||||
async def __aenter__(self):
|
||||
try:
|
||||
return await self.gen.__anext__()
|
||||
except StopAsyncIteration:
|
||||
raise RuntimeError("generator didn't yield") from None
|
||||
|
||||
async def __aexit__(self, typ, value, traceback):
|
||||
if typ is None:
|
||||
try:
|
||||
await self.gen.__anext__()
|
||||
except StopAsyncIteration:
|
||||
return
|
||||
else:
|
||||
raise RuntimeError("generator didn't stop")
|
||||
else:
|
||||
if value is None:
|
||||
value = typ()
|
||||
# See _GeneratorContextManager.__exit__ for comments on subtleties
|
||||
# in this implementation
|
||||
try:
|
||||
await self.gen.athrow(typ, value, traceback)
|
||||
raise RuntimeError("generator didn't stop after throw()")
|
||||
except StopAsyncIteration as exc:
|
||||
return exc is not value
|
||||
except RuntimeError as exc:
|
||||
if exc is value:
|
||||
return False
|
||||
# Avoid suppressing if a StopIteration exception
|
||||
# was passed to throw() and later wrapped into a RuntimeError
|
||||
# (see PEP 479 for sync generators; async generators also
|
||||
# have this behavior). But do this only if the exception wrapped
|
||||
# by the RuntimeError is actully Stop(Async)Iteration (see
|
||||
# issue29692).
|
||||
if isinstance(value, (StopIteration, StopAsyncIteration)):
|
||||
if exc.__cause__ is value:
|
||||
return False
|
||||
raise
|
||||
except BaseException as exc:
|
||||
if exc is not value:
|
||||
raise
|
||||
|
||||
|
||||
def contextmanager(func):
|
||||
"""@contextmanager decorator.
|
||||
|
||||
|
@ -153,7 +208,6 @@ def contextmanager(func):
|
|||
<body>
|
||||
finally:
|
||||
<cleanup>
|
||||
|
||||
"""
|
||||
@wraps(func)
|
||||
def helper(*args, **kwds):
|
||||
|
@ -161,6 +215,39 @@ def contextmanager(func):
|
|||
return helper
|
||||
|
||||
|
||||
def asynccontextmanager(func):
|
||||
"""@asynccontextmanager decorator.
|
||||
|
||||
Typical usage:
|
||||
|
||||
@asynccontextmanager
|
||||
async def some_async_generator(<arguments>):
|
||||
<setup>
|
||||
try:
|
||||
yield <value>
|
||||
finally:
|
||||
<cleanup>
|
||||
|
||||
This makes this:
|
||||
|
||||
async with some_async_generator(<arguments>) as <variable>:
|
||||
<body>
|
||||
|
||||
equivalent to this:
|
||||
|
||||
<setup>
|
||||
try:
|
||||
<variable> = <value>
|
||||
<body>
|
||||
finally:
|
||||
<cleanup>
|
||||
"""
|
||||
@wraps(func)
|
||||
def helper(*args, **kwds):
|
||||
return _AsyncGeneratorContextManager(func, args, kwds)
|
||||
return helper
|
||||
|
||||
|
||||
class closing(AbstractContextManager):
|
||||
"""Context to automatically close something at the end of a block.
|
||||
|
||||
|
|
|
@ -0,0 +1,212 @@
|
|||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
import functools
|
||||
from test import support
|
||||
import unittest
|
||||
|
||||
|
||||
def _async_test(func):
|
||||
"""Decorator to turn an async function into a test case."""
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
coro = func(*args, **kwargs)
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
return wrapper
|
||||
|
||||
|
||||
class AsyncContextManagerTestCase(unittest.TestCase):
|
||||
|
||||
@_async_test
|
||||
async def test_contextmanager_plain(self):
|
||||
state = []
|
||||
@asynccontextmanager
|
||||
async def woohoo():
|
||||
state.append(1)
|
||||
yield 42
|
||||
state.append(999)
|
||||
async with woohoo() as x:
|
||||
self.assertEqual(state, [1])
|
||||
self.assertEqual(x, 42)
|
||||
state.append(x)
|
||||
self.assertEqual(state, [1, 42, 999])
|
||||
|
||||
@_async_test
|
||||
async def test_contextmanager_finally(self):
|
||||
state = []
|
||||
@asynccontextmanager
|
||||
async def woohoo():
|
||||
state.append(1)
|
||||
try:
|
||||
yield 42
|
||||
finally:
|
||||
state.append(999)
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
async with woohoo() as x:
|
||||
self.assertEqual(state, [1])
|
||||
self.assertEqual(x, 42)
|
||||
state.append(x)
|
||||
raise ZeroDivisionError()
|
||||
self.assertEqual(state, [1, 42, 999])
|
||||
|
||||
@_async_test
|
||||
async def test_contextmanager_no_reraise(self):
|
||||
@asynccontextmanager
|
||||
async def whee():
|
||||
yield
|
||||
ctx = whee()
|
||||
await ctx.__aenter__()
|
||||
# Calling __aexit__ should not result in an exception
|
||||
self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None))
|
||||
|
||||
@_async_test
|
||||
async def test_contextmanager_trap_yield_after_throw(self):
|
||||
@asynccontextmanager
|
||||
async def whoo():
|
||||
try:
|
||||
yield
|
||||
except:
|
||||
yield
|
||||
ctx = whoo()
|
||||
await ctx.__aenter__()
|
||||
with self.assertRaises(RuntimeError):
|
||||
await ctx.__aexit__(TypeError, TypeError('foo'), None)
|
||||
|
||||
@_async_test
|
||||
async def test_contextmanager_trap_no_yield(self):
|
||||
@asynccontextmanager
|
||||
async def whoo():
|
||||
if False:
|
||||
yield
|
||||
ctx = whoo()
|
||||
with self.assertRaises(RuntimeError):
|
||||
await ctx.__aenter__()
|
||||
|
||||
@_async_test
|
||||
async def test_contextmanager_trap_second_yield(self):
|
||||
@asynccontextmanager
|
||||
async def whoo():
|
||||
yield
|
||||
yield
|
||||
ctx = whoo()
|
||||
await ctx.__aenter__()
|
||||
with self.assertRaises(RuntimeError):
|
||||
await ctx.__aexit__(None, None, None)
|
||||
|
||||
@_async_test
|
||||
async def test_contextmanager_non_normalised(self):
|
||||
@asynccontextmanager
|
||||
async def whoo():
|
||||
try:
|
||||
yield
|
||||
except RuntimeError:
|
||||
raise SyntaxError
|
||||
|
||||
ctx = whoo()
|
||||
await ctx.__aenter__()
|
||||
with self.assertRaises(SyntaxError):
|
||||
await ctx.__aexit__(RuntimeError, None, None)
|
||||
|
||||
@_async_test
|
||||
async def test_contextmanager_except(self):
|
||||
state = []
|
||||
@asynccontextmanager
|
||||
async def woohoo():
|
||||
state.append(1)
|
||||
try:
|
||||
yield 42
|
||||
except ZeroDivisionError as e:
|
||||
state.append(e.args[0])
|
||||
self.assertEqual(state, [1, 42, 999])
|
||||
async with woohoo() as x:
|
||||
self.assertEqual(state, [1])
|
||||
self.assertEqual(x, 42)
|
||||
state.append(x)
|
||||
raise ZeroDivisionError(999)
|
||||
self.assertEqual(state, [1, 42, 999])
|
||||
|
||||
@_async_test
|
||||
async def test_contextmanager_except_stopiter(self):
|
||||
@asynccontextmanager
|
||||
async def woohoo():
|
||||
yield
|
||||
|
||||
for stop_exc in (StopIteration('spam'), StopAsyncIteration('ham')):
|
||||
with self.subTest(type=type(stop_exc)):
|
||||
try:
|
||||
async with woohoo():
|
||||
raise stop_exc
|
||||
except Exception as ex:
|
||||
self.assertIs(ex, stop_exc)
|
||||
else:
|
||||
self.fail(f'{stop_exc} was suppressed')
|
||||
|
||||
@_async_test
|
||||
async def test_contextmanager_wrap_runtimeerror(self):
|
||||
@asynccontextmanager
|
||||
async def woohoo():
|
||||
try:
|
||||
yield
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f'caught {exc}') from exc
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
async with woohoo():
|
||||
1 / 0
|
||||
|
||||
# If the context manager wrapped StopAsyncIteration in a RuntimeError,
|
||||
# we also unwrap it, because we can't tell whether the wrapping was
|
||||
# done by the generator machinery or by the generator itself.
|
||||
with self.assertRaises(StopAsyncIteration):
|
||||
async with woohoo():
|
||||
raise StopAsyncIteration
|
||||
|
||||
def _create_contextmanager_attribs(self):
|
||||
def attribs(**kw):
|
||||
def decorate(func):
|
||||
for k,v in kw.items():
|
||||
setattr(func,k,v)
|
||||
return func
|
||||
return decorate
|
||||
@asynccontextmanager
|
||||
@attribs(foo='bar')
|
||||
async def baz(spam):
|
||||
"""Whee!"""
|
||||
yield
|
||||
return baz
|
||||
|
||||
def test_contextmanager_attribs(self):
|
||||
baz = self._create_contextmanager_attribs()
|
||||
self.assertEqual(baz.__name__,'baz')
|
||||
self.assertEqual(baz.foo, 'bar')
|
||||
|
||||
@support.requires_docstrings
|
||||
def test_contextmanager_doc_attrib(self):
|
||||
baz = self._create_contextmanager_attribs()
|
||||
self.assertEqual(baz.__doc__, "Whee!")
|
||||
|
||||
@support.requires_docstrings
|
||||
@_async_test
|
||||
async def test_instance_docstring_given_cm_docstring(self):
|
||||
baz = self._create_contextmanager_attribs()(None)
|
||||
self.assertEqual(baz.__doc__, "Whee!")
|
||||
async with baz:
|
||||
pass # suppress warning
|
||||
|
||||
@_async_test
|
||||
async def test_keywords(self):
|
||||
# Ensure no keyword arguments are inhibited
|
||||
@asynccontextmanager
|
||||
async def woohoo(self, func, args, kwds):
|
||||
yield (self, func, args, kwds)
|
||||
async with woohoo(self=11, func=22, args=33, kwds=44) as target:
|
||||
self.assertEqual(target, (11, 22, 33, 44))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in New Issue