213 lines
6.4 KiB
Python
213 lines
6.4 KiB
Python
|
import asyncio
|
||
|
from contextlib import asynccontextmanager
|
||
|
import functools
|
||
|
from test import support
|
||
|
import unittest
|
||
|
|
||
|
|
||
|
def _async_test(func):
    """Decorator that lets a coroutine function run as a synchronous test.

    The returned wrapper calls *func* with the arguments it receives to
    obtain a coroutine, drives that coroutine to completion on a brand-new
    event loop, and returns the coroutine's result.  Any exception raised by
    the coroutine propagates to the caller.  The loop is always closed and
    the current event loop reset to ``None`` afterwards, whether or not the
    coroutine succeeded.
    """
    @functools.wraps(func)
    def run_on_fresh_loop(*args, **kwargs):
        pending = func(*args, **kwargs)
        event_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(event_loop)
        try:
            outcome = event_loop.run_until_complete(pending)
        finally:
            # Tear down in the same order as setup is undone: close the
            # loop first, then clear it as the current event loop.
            event_loop.close()
            asyncio.set_event_loop(None)
        return outcome

    return run_on_fresh_loop
|
||
|
|
||
|
|
||
|
class AsyncContextManagerTestCase(unittest.TestCase):
    """Tests for context managers built with ``asynccontextmanager``."""

    @_async_test
    async def test_contextmanager_plain(self):
        # Code before the yield runs on entry, code after it runs on exit.
        events = []

        @asynccontextmanager
        async def managed():
            events.append(1)
            yield 42
            events.append(999)

        async with managed() as value:
            self.assertEqual(events, [1])
            self.assertEqual(value, 42)
            events.append(value)

        self.assertEqual(events, [1, 42, 999])

    @_async_test
    async def test_contextmanager_finally(self):
        # A ``finally`` clause in the generator runs even when the body of
        # the ``async with`` raises, and the exception still propagates.
        events = []

        @asynccontextmanager
        async def managed():
            events.append(1)
            try:
                yield 42
            finally:
                events.append(999)

        with self.assertRaises(ZeroDivisionError):
            async with managed() as value:
                self.assertEqual(events, [1])
                self.assertEqual(value, 42)
                events.append(value)
                raise ZeroDivisionError()

        self.assertEqual(events, [1, 42, 999])

    @_async_test
    async def test_contextmanager_no_reraise(self):
        @asynccontextmanager
        async def single_yield():
            yield

        manager = single_yield()
        await manager.__aenter__()
        # Calling __aexit__ directly must not raise; it reports a false
        # value, i.e. the exception was not suppressed.
        suppressed = await manager.__aexit__(TypeError, TypeError("foo"), None)
        self.assertFalse(suppressed)

    @_async_test
    async def test_contextmanager_trap_yield_after_throw(self):
        @asynccontextmanager
        async def yields_after_throw():
            try:
                yield
            except:
                # Deliberately catch everything and yield a second time:
                # this is the misbehaviour the test expects to be trapped.
                yield

        manager = yields_after_throw()
        await manager.__aenter__()
        with self.assertRaises(RuntimeError):
            await manager.__aexit__(TypeError, TypeError('foo'), None)

    @_async_test
    async def test_contextmanager_trap_no_yield(self):
        @asynccontextmanager
        async def never_yields():
            # The unreachable yield keeps this an async generator function
            # while guaranteeing it produces no value.
            if False:
                yield

        manager = never_yields()
        with self.assertRaises(RuntimeError):
            await manager.__aenter__()

    @_async_test
    async def test_contextmanager_trap_second_yield(self):
        @asynccontextmanager
        async def yields_twice():
            yield
            yield

        manager = yields_twice()
        await manager.__aenter__()
        with self.assertRaises(RuntimeError):
            await manager.__aexit__(None, None, None)

    @_async_test
    async def test_contextmanager_non_normalised(self):
        @asynccontextmanager
        async def translating():
            try:
                yield
            except RuntimeError:
                raise SyntaxError

        manager = translating()
        await manager.__aenter__()
        # __aexit__ is handed an exception type with no instance (value is
        # None); the generator's handler must still see a RuntimeError.
        with self.assertRaises(SyntaxError):
            await manager.__aexit__(RuntimeError, None, None)

    @_async_test
    async def test_contextmanager_except(self):
        # An exception handled inside the generator is swallowed, so the
        # ``async with`` statement completes normally.
        events = []

        @asynccontextmanager
        async def managed():
            events.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                events.append(e.args[0])
                self.assertEqual(events, [1, 42, 999])

        async with managed() as value:
            self.assertEqual(events, [1])
            self.assertEqual(value, 42)
            events.append(value)
            raise ZeroDivisionError(999)

        self.assertEqual(events, [1, 42, 999])

    @_async_test
    async def test_contextmanager_except_stopiter(self):
        @asynccontextmanager
        async def managed():
            yield

        # Both stop-iteration flavours raised in the body must come back
        # out as the very same exception object, not be suppressed.
        candidates = (StopIteration('spam'), StopAsyncIteration('ham'))
        for stop_exc in candidates:
            with self.subTest(type=type(stop_exc)):
                try:
                    async with managed():
                        raise stop_exc
                except Exception as ex:
                    self.assertIs(ex, stop_exc)
                else:
                    self.fail(f'{stop_exc} was suppressed')

    @_async_test
    async def test_contextmanager_wrap_runtimeerror(self):
        @asynccontextmanager
        async def wrapping():
            try:
                yield
            except Exception as exc:
                raise RuntimeError(f'caught {exc}') from exc

        with self.assertRaises(RuntimeError):
            async with wrapping():
                1 / 0

        # If the context manager wrapped StopAsyncIteration in a RuntimeError,
        # we also unwrap it, because we can't tell whether the wrapping was
        # done by the generator machinery or by the generator itself.
        with self.assertRaises(StopAsyncIteration):
            async with wrapping():
                raise StopAsyncIteration

    def _create_contextmanager_attribs(self):
        # Build an asynccontextmanager-decorated function that carries a
        # custom attribute (foo='bar') and a docstring, for the tests that
        # check metadata is preserved.
        def attribs(**kw):
            def decorate(func):
                for attr_name, attr_value in kw.items():
                    setattr(func, attr_name, attr_value)
                return func
            return decorate

        @asynccontextmanager
        @attribs(foo='bar')
        async def baz(spam):
            """Whee!"""
            yield

        return baz

    def test_contextmanager_attribs(self):
        factory = self._create_contextmanager_attribs()
        self.assertEqual(factory.__name__, 'baz')
        self.assertEqual(factory.foo, 'bar')

    @support.requires_docstrings
    def test_contextmanager_doc_attrib(self):
        factory = self._create_contextmanager_attribs()
        self.assertEqual(factory.__doc__, "Whee!")

    @support.requires_docstrings
    @_async_test
    async def test_instance_docstring_given_cm_docstring(self):
        factory = self._create_contextmanager_attribs()
        manager = factory(None)
        self.assertEqual(manager.__doc__, "Whee!")
        async with manager:
            pass  # suppress warning

    @_async_test
    async def test_keywords(self):
        # Ensure no keyword arguments are inhibited
        @asynccontextmanager
        async def echo(self, func, args, kwds):
            yield (self, func, args, kwds)

        async with echo(self=11, func=22, args=33, kwds=44) as target:
            self.assertEqual(target, (11, 22, 33, 44))
|
||
|
|
||
|
|
||
|
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|