bpo-38857: AsyncMock fix for awaitable values and StopIteration fix [3.8] (GH-17269)
This commit is contained in:
parent
e5d1f734db
commit
046442d02b
|
@ -873,7 +873,7 @@ object::
|
|||
exception,
|
||||
- if ``side_effect`` is an iterable, the async function will return the
|
||||
next value of the iterable, however, if the sequence of result is
|
||||
exhausted, ``StopIteration`` is raised immediately,
|
||||
exhausted, ``StopAsyncIteration`` is raised immediately,
|
||||
- if ``side_effect`` is not defined, the async function will return the
|
||||
value defined by ``return_value``, hence, by default, the async function
|
||||
returns a new :class:`AsyncMock` object.
|
||||
|
|
|
@ -1139,8 +1139,8 @@ class CallableMixin(Base):
|
|||
_new_parent = _new_parent._mock_new_parent
|
||||
|
||||
def _execute_mock_call(self, /, *args, **kwargs):
|
||||
# seperate from _increment_mock_call so that awaited functions are
|
||||
# executed seperately from their call
|
||||
# separate from _increment_mock_call so that awaited functions are
|
||||
# executed separately from their call, also AsyncMock overrides this method
|
||||
|
||||
effect = self.side_effect
|
||||
if effect is not None:
|
||||
|
@ -2136,29 +2136,45 @@ class AsyncMockMixin(Base):
|
|||
code_mock.co_flags = inspect.CO_COROUTINE
|
||||
self.__dict__['__code__'] = code_mock
|
||||
|
||||
async def _mock_call(self, /, *args, **kwargs):
|
||||
try:
|
||||
result = super()._mock_call(*args, **kwargs)
|
||||
except (BaseException, StopIteration) as e:
|
||||
side_effect = self.side_effect
|
||||
if side_effect is not None and not callable(side_effect):
|
||||
raise
|
||||
return await _raise(e)
|
||||
async def _execute_mock_call(self, /, *args, **kwargs):
|
||||
# This is nearly just like super(), except for sepcial handling
|
||||
# of coroutines
|
||||
|
||||
_call = self.call_args
|
||||
self.await_count += 1
|
||||
self.await_args = _call
|
||||
self.await_args_list.append(_call)
|
||||
|
||||
async def proxy():
|
||||
try:
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
else:
|
||||
return result
|
||||
finally:
|
||||
self.await_count += 1
|
||||
self.await_args = _call
|
||||
self.await_args_list.append(_call)
|
||||
effect = self.side_effect
|
||||
if effect is not None:
|
||||
if _is_exception(effect):
|
||||
raise effect
|
||||
elif not _callable(effect):
|
||||
try:
|
||||
result = next(effect)
|
||||
except StopIteration:
|
||||
# It is impossible to propogate a StopIteration
|
||||
# through coroutines because of PEP 479
|
||||
raise StopAsyncIteration
|
||||
if _is_exception(result):
|
||||
raise result
|
||||
elif asyncio.iscoroutinefunction(effect):
|
||||
result = await effect(*args, **kwargs)
|
||||
else:
|
||||
result = effect(*args, **kwargs)
|
||||
|
||||
return await proxy()
|
||||
if result is not DEFAULT:
|
||||
return result
|
||||
|
||||
if self._mock_return_value is not DEFAULT:
|
||||
return self.return_value
|
||||
|
||||
if self._mock_wraps is not None:
|
||||
if asyncio.iscoroutinefunction(self._mock_wraps):
|
||||
return await self._mock_wraps(*args, **kwargs)
|
||||
return self._mock_wraps(*args, **kwargs)
|
||||
|
||||
return self.return_value
|
||||
|
||||
def assert_awaited(self):
|
||||
"""
|
||||
|
@ -2864,10 +2880,6 @@ def seal(mock):
|
|||
seal(m)
|
||||
|
||||
|
||||
async def _raise(exception):
|
||||
raise exception
|
||||
|
||||
|
||||
class _AsyncIterator:
|
||||
"""
|
||||
Wraps an iterator in an asynchronous iterator.
|
||||
|
|
|
@ -358,42 +358,84 @@ class AsyncSpecSetTest(unittest.TestCase):
|
|||
self.assertIsInstance(cm, MagicMock)
|
||||
|
||||
|
||||
class AsyncArguments(unittest.TestCase):
|
||||
def test_add_return_value(self):
|
||||
class AsyncArguments(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_add_return_value(self):
|
||||
async def addition(self, var):
|
||||
return var + 1
|
||||
|
||||
mock = AsyncMock(addition, return_value=10)
|
||||
output = asyncio.run(mock(5))
|
||||
output = await mock(5)
|
||||
|
||||
self.assertEqual(output, 10)
|
||||
|
||||
def test_add_side_effect_exception(self):
|
||||
async def test_add_side_effect_exception(self):
|
||||
async def addition(var):
|
||||
return var + 1
|
||||
mock = AsyncMock(addition, side_effect=Exception('err'))
|
||||
with self.assertRaises(Exception):
|
||||
asyncio.run(mock(5))
|
||||
await mock(5)
|
||||
|
||||
def test_add_side_effect_function(self):
|
||||
async def test_add_side_effect_function(self):
|
||||
async def addition(var):
|
||||
return var + 1
|
||||
mock = AsyncMock(side_effect=addition)
|
||||
result = asyncio.run(mock(5))
|
||||
result = await mock(5)
|
||||
self.assertEqual(result, 6)
|
||||
|
||||
def test_add_side_effect_iterable(self):
|
||||
async def test_add_side_effect_iterable(self):
|
||||
vals = [1, 2, 3]
|
||||
mock = AsyncMock(side_effect=vals)
|
||||
for item in vals:
|
||||
self.assertEqual(item, asyncio.run(mock()))
|
||||
self.assertEqual(item, await mock())
|
||||
|
||||
with self.assertRaises(RuntimeError) as e:
|
||||
asyncio.run(mock())
|
||||
self.assertEqual(
|
||||
e.exception,
|
||||
RuntimeError('coroutine raised StopIteration')
|
||||
)
|
||||
with self.assertRaises(StopAsyncIteration) as e:
|
||||
await mock()
|
||||
|
||||
async def test_return_value_AsyncMock(self):
|
||||
value = AsyncMock(return_value=10)
|
||||
mock = AsyncMock(return_value=value)
|
||||
result = await mock()
|
||||
self.assertIs(result, value)
|
||||
|
||||
async def test_return_value_awaitable(self):
|
||||
fut = asyncio.Future()
|
||||
fut.set_result(None)
|
||||
mock = AsyncMock(return_value=fut)
|
||||
result = await mock()
|
||||
self.assertIsInstance(result, asyncio.Future)
|
||||
|
||||
async def test_side_effect_awaitable_values(self):
|
||||
fut = asyncio.Future()
|
||||
fut.set_result(None)
|
||||
|
||||
mock = AsyncMock(side_effect=[fut])
|
||||
result = await mock()
|
||||
self.assertIsInstance(result, asyncio.Future)
|
||||
|
||||
with self.assertRaises(StopAsyncIteration):
|
||||
await mock()
|
||||
|
||||
async def test_side_effect_is_AsyncMock(self):
|
||||
effect = AsyncMock(return_value=10)
|
||||
mock = AsyncMock(side_effect=effect)
|
||||
|
||||
result = await mock()
|
||||
self.assertEqual(result, 10)
|
||||
|
||||
async def test_wraps_coroutine(self):
|
||||
value = asyncio.Future()
|
||||
|
||||
ran = False
|
||||
async def inner():
|
||||
nonlocal ran
|
||||
ran = True
|
||||
return value
|
||||
|
||||
mock = AsyncMock(wraps=inner)
|
||||
result = await mock()
|
||||
self.assertEqual(result, value)
|
||||
mock.assert_awaited()
|
||||
self.assertTrue(ran)
|
||||
|
||||
class AsyncMagicMethods(unittest.TestCase):
|
||||
def test_async_magic_methods_return_async_mocks(self):
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
AsyncMock fix for return values that are awaitable types. This also covers
|
||||
side_effect iterable values that happend to be awaitable, and wraps
|
||||
callables that return an awaitable type. Before these awaitables were being
|
||||
awaited instead of being returned as is.
|
|
@ -0,0 +1,3 @@
|
|||
AsyncMock now returns StopAsyncIteration on the exaustion of a side_effects
|
||||
iterable. Since PEP-479 its Impossible to raise a StopIteration exception
|
||||
from a coroutine.
|
Loading…
Reference in New Issue