bpo-26467: Adds AsyncMock for asyncio Mock library support (GH-9296)
This commit is contained in:
parent
0f72147ce2
commit
77b3b7701a
|
@ -201,9 +201,11 @@ The Mock Class
|
|||
|
||||
.. testsetup::
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import unittest
|
||||
from unittest.mock import sentinel, DEFAULT, ANY
|
||||
from unittest.mock import patch, call, Mock, MagicMock, PropertyMock
|
||||
from unittest.mock import patch, call, Mock, MagicMock, PropertyMock, AsyncMock
|
||||
from unittest.mock import mock_open
|
||||
|
||||
:class:`Mock` is a flexible mock object intended to replace the use of stubs and
|
||||
|
@ -851,6 +853,217 @@ object::
|
|||
>>> p.assert_called_once_with()
|
||||
|
||||
|
||||
.. class:: AsyncMock(spec=None, side_effect=None, return_value=DEFAULT, wraps=None, name=None, spec_set=None, unsafe=False, **kwargs)
|
||||
|
||||
An asynchronous version of :class:`Mock`. The :class:`AsyncMock` object will
|
||||
behave so the object is recognized as an async function, and the result of a
|
||||
call is an awaitable.
|
||||
|
||||
>>> mock = AsyncMock()
|
||||
>>> asyncio.iscoroutinefunction(mock)
|
||||
True
|
||||
>>> inspect.isawaitable(mock())
|
||||
True
|
||||
|
||||
The result of ``mock()`` is an async function which will have the outcome
|
||||
of ``side_effect`` or ``return_value``:
|
||||
|
||||
- if ``side_effect`` is a function, the async function will return the
|
||||
result of that function,
|
||||
- if ``side_effect`` is an exception, the async function will raise the
|
||||
exception,
|
||||
- if ``side_effect`` is an iterable, the async function will return the
|
||||
next value of the iterable, however, if the sequence of result is
|
||||
exhausted, ``StopIteration`` is raised immediately,
|
||||
- if ``side_effect`` is not defined, the async function will return the
|
||||
value defined by ``return_value``, hence, by default, the async function
|
||||
returns a new :class:`AsyncMock` object.
|
||||
|
||||
|
||||
Setting the *spec* of a :class:`Mock` or :class:`MagicMock` to an async function
|
||||
will result in a coroutine object being returned after calling.
|
||||
|
||||
>>> async def async_func(): pass
|
||||
...
|
||||
>>> mock = MagicMock(async_func)
|
||||
>>> mock
|
||||
<MagicMock spec='function' id='...'>
|
||||
>>> mock()
|
||||
<coroutine object AsyncMockMixin._mock_call at ...>
|
||||
|
||||
.. method:: assert_awaited()
|
||||
|
||||
Assert that the mock was awaited at least once.
|
||||
|
||||
>>> mock = AsyncMock()
|
||||
>>> async def main():
|
||||
... await mock()
|
||||
...
|
||||
>>> asyncio.run(main())
|
||||
>>> mock.assert_awaited()
|
||||
>>> mock_2 = AsyncMock()
|
||||
>>> mock_2.assert_awaited()
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AssertionError: Expected mock to have been awaited.
|
||||
|
||||
.. method:: assert_awaited_once()
|
||||
|
||||
Assert that the mock was awaited exactly once.
|
||||
|
||||
>>> mock = AsyncMock()
|
||||
>>> async def main():
|
||||
... await mock()
|
||||
...
|
||||
>>> asyncio.run(main())
|
||||
>>> mock.assert_awaited_once()
|
||||
>>> asyncio.run(main())
|
||||
>>> mock.method.assert_awaited_once()
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AssertionError: Expected mock to have been awaited once. Awaited 2 times.
|
||||
|
||||
.. method:: assert_awaited_with(*args, **kwargs)
|
||||
|
||||
Assert that the last await was with the specified arguments.
|
||||
|
||||
>>> mock = AsyncMock()
|
||||
>>> async def main(*args, **kwargs):
|
||||
... await mock(*args, **kwargs)
|
||||
...
|
||||
>>> asyncio.run(main('foo', bar='bar'))
|
||||
>>> mock.assert_awaited_with('foo', bar='bar')
|
||||
>>> mock.assert_awaited_with('other')
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AssertionError: expected call not found.
|
||||
Expected: mock('other')
|
||||
Actual: mock('foo', bar='bar')
|
||||
|
||||
.. method:: assert_awaited_once_with(*args, **kwargs)
|
||||
|
||||
Assert that the mock was awaited exactly once and with the specified
|
||||
arguments.
|
||||
|
||||
>>> mock = AsyncMock()
|
||||
>>> async def main(*args, **kwargs):
|
||||
... await mock(*args, **kwargs)
|
||||
...
|
||||
>>> asyncio.run(main('foo', bar='bar'))
|
||||
>>> mock.assert_awaited_once_with('foo', bar='bar')
|
||||
>>> asyncio.run(main('foo', bar='bar'))
|
||||
>>> mock.assert_awaited_once_with('foo', bar='bar')
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AssertionError: Expected mock to have been awaited once. Awaited 2 times.
|
||||
|
||||
.. method:: assert_any_await(*args, **kwargs)
|
||||
|
||||
Assert the mock has ever been awaited with the specified arguments.
|
||||
|
||||
>>> mock = AsyncMock()
|
||||
>>> async def main(*args, **kwargs):
|
||||
... await mock(*args, **kwargs)
|
||||
...
|
||||
>>> asyncio.run(main('foo', bar='bar'))
|
||||
>>> asyncio.run(main('hello'))
|
||||
>>> mock.assert_any_await('foo', bar='bar')
|
||||
>>> mock.assert_any_await('other')
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AssertionError: mock('other') await not found
|
||||
|
||||
.. method:: assert_has_awaits(calls, any_order=False)
|
||||
|
||||
Assert the mock has been awaited with the specified calls.
|
||||
The :attr:`await_args_list` list is checked for the awaits.
|
||||
|
||||
If *any_order* is False (the default) then the awaits must be
|
||||
sequential. There can be extra calls before or after the
|
||||
specified awaits.
|
||||
|
||||
If *any_order* is True then the awaits can be in any order, but
|
||||
they must all appear in :attr:`await_args_list`.
|
||||
|
||||
>>> mock = AsyncMock()
|
||||
>>> async def main(*args, **kwargs):
|
||||
... await mock(*args, **kwargs)
|
||||
...
|
||||
>>> calls = [call("foo"), call("bar")]
|
||||
>>> mock.assert_has_calls(calls)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AssertionError: Calls not found.
|
||||
Expected: [call('foo'), call('bar')]
|
||||
>>> asyncio.run(main('foo'))
|
||||
>>> asyncio.run(main('bar'))
|
||||
>>> mock.assert_has_calls(calls)
|
||||
|
||||
.. method:: assert_not_awaited()
|
||||
|
||||
Assert that the mock was never awaited.
|
||||
|
||||
>>> mock = AsyncMock()
|
||||
>>> mock.assert_not_awaited()
|
||||
|
||||
.. method:: reset_mock(*args, **kwargs)
|
||||
|
||||
See :func:`Mock.reset_mock`. Also sets :attr:`await_count` to 0,
|
||||
:attr:`await_args` to None, and clears the :attr:`await_args_list`.
|
||||
|
||||
.. attribute:: await_count
|
||||
|
||||
An integer keeping track of how many times the mock object has been awaited.
|
||||
|
||||
>>> mock = AsyncMock()
|
||||
>>> async def main():
|
||||
... await mock()
|
||||
...
|
||||
>>> asyncio.run(main())
|
||||
>>> mock.await_count
|
||||
1
|
||||
>>> asyncio.run(main())
|
||||
>>> mock.await_count
|
||||
2
|
||||
|
||||
.. attribute:: await_args
|
||||
|
||||
This is either ``None`` (if the mock hasn’t been awaited), or the arguments that
|
||||
the mock was last awaited with. Functions the same as :attr:`Mock.call_args`.
|
||||
|
||||
>>> mock = AsyncMock()
|
||||
>>> async def main(*args):
|
||||
... await mock(*args)
|
||||
...
|
||||
>>> mock.await_args
|
||||
>>> asyncio.run(main('foo'))
|
||||
>>> mock.await_args
|
||||
call('foo')
|
||||
>>> asyncio.run(main('bar'))
|
||||
>>> mock.await_args
|
||||
call('bar')
|
||||
|
||||
|
||||
.. attribute:: await_args_list
|
||||
|
||||
This is a list of all the awaits made to the mock object in sequence (so the
|
||||
length of the list is the number of times it has been awaited). Before any
|
||||
awaits have been made it is an empty list.
|
||||
|
||||
>>> mock = AsyncMock()
|
||||
>>> async def main(*args):
|
||||
... await mock(*args)
|
||||
...
|
||||
>>> mock.await_args_list
|
||||
[]
|
||||
>>> asyncio.run(main('foo'))
|
||||
>>> mock.await_args_list
|
||||
[call('foo')]
|
||||
>>> asyncio.run(main('bar'))
|
||||
>>> mock.await_args_list
|
||||
[call('foo'), call('bar')]
|
||||
|
||||
|
||||
Calling
|
||||
~~~~~~~
|
||||
|
||||
|
|
|
@ -538,6 +538,10 @@ unicodedata
|
|||
unittest
|
||||
--------
|
||||
|
||||
* XXX Added :class:`AsyncMock` to support an asynchronous version of :class:`Mock`.
|
||||
Appropriate new assert functions for testing have been added as well.
|
||||
(Contributed by Lisa Roach in :issue:`26467`).
|
||||
|
||||
* Added :func:`~unittest.addModuleCleanup()` and
|
||||
:meth:`~unittest.TestCase.addClassCleanup()` to unittest to support
|
||||
cleanups for :func:`~unittest.setUpModule()` and
|
||||
|
|
|
@ -13,6 +13,7 @@ __all__ = (
|
|||
'ANY',
|
||||
'call',
|
||||
'create_autospec',
|
||||
'AsyncMock',
|
||||
'FILTER_DIR',
|
||||
'NonCallableMock',
|
||||
'NonCallableMagicMock',
|
||||
|
@ -24,13 +25,13 @@ __all__ = (
|
|||
|
||||
__version__ = '1.0'
|
||||
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import inspect
|
||||
import pprint
|
||||
import sys
|
||||
import builtins
|
||||
from types import ModuleType, MethodType
|
||||
from types import CodeType, ModuleType, MethodType
|
||||
from unittest.util import safe_repr
|
||||
from functools import wraps, partial
|
||||
|
||||
|
@ -43,6 +44,13 @@ FILTER_DIR = True
|
|||
# Without this, the __class__ properties wouldn't be set correctly
|
||||
_safe_super = super
|
||||
|
||||
def _is_async_obj(obj):
|
||||
if getattr(obj, '__code__', None):
|
||||
return asyncio.iscoroutinefunction(obj) or inspect.isawaitable(obj)
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def _is_instance_mock(obj):
|
||||
# can't use isinstance on Mock objects because they override __class__
|
||||
# The base class for all mocks is NonCallableMock
|
||||
|
@ -355,7 +363,20 @@ class NonCallableMock(Base):
|
|||
# every instance has its own class
|
||||
# so we can create magic methods on the
|
||||
# class without stomping on other mocks
|
||||
new = type(cls.__name__, (cls,), {'__doc__': cls.__doc__})
|
||||
bases = (cls,)
|
||||
if not issubclass(cls, AsyncMock):
|
||||
# Check if spec is an async object or function
|
||||
sig = inspect.signature(NonCallableMock.__init__)
|
||||
bound_args = sig.bind_partial(cls, *args, **kw).arguments
|
||||
spec_arg = [
|
||||
arg for arg in bound_args.keys()
|
||||
if arg.startswith('spec')
|
||||
]
|
||||
if spec_arg:
|
||||
# what if spec_set is different than spec?
|
||||
if _is_async_obj(bound_args[spec_arg[0]]):
|
||||
bases = (AsyncMockMixin, cls,)
|
||||
new = type(cls.__name__, bases, {'__doc__': cls.__doc__})
|
||||
instance = object.__new__(new)
|
||||
return instance
|
||||
|
||||
|
@ -431,6 +452,11 @@ class NonCallableMock(Base):
|
|||
_eat_self=False):
|
||||
_spec_class = None
|
||||
_spec_signature = None
|
||||
_spec_asyncs = []
|
||||
|
||||
for attr in dir(spec):
|
||||
if asyncio.iscoroutinefunction(getattr(spec, attr, None)):
|
||||
_spec_asyncs.append(attr)
|
||||
|
||||
if spec is not None and not _is_list(spec):
|
||||
if isinstance(spec, type):
|
||||
|
@ -448,7 +474,7 @@ class NonCallableMock(Base):
|
|||
__dict__['_spec_set'] = spec_set
|
||||
__dict__['_spec_signature'] = _spec_signature
|
||||
__dict__['_mock_methods'] = spec
|
||||
|
||||
__dict__['_spec_asyncs'] = _spec_asyncs
|
||||
|
||||
def __get_return_value(self):
|
||||
ret = self._mock_return_value
|
||||
|
@ -886,7 +912,15 @@ class NonCallableMock(Base):
|
|||
|
||||
For non-callable mocks the callable variant will be used (rather than
|
||||
any custom subclass)."""
|
||||
_new_name = kw.get("_new_name")
|
||||
if _new_name in self.__dict__['_spec_asyncs']:
|
||||
return AsyncMock(**kw)
|
||||
|
||||
_type = type(self)
|
||||
if issubclass(_type, MagicMock) and _new_name in _async_method_magics:
|
||||
klass = AsyncMock
|
||||
if issubclass(_type, AsyncMockMixin):
|
||||
klass = MagicMock
|
||||
if not issubclass(_type, CallableMixin):
|
||||
if issubclass(_type, NonCallableMagicMock):
|
||||
klass = MagicMock
|
||||
|
@ -932,14 +966,12 @@ def _try_iter(obj):
|
|||
return obj
|
||||
|
||||
|
||||
|
||||
class CallableMixin(Base):
|
||||
|
||||
def __init__(self, spec=None, side_effect=None, return_value=DEFAULT,
|
||||
wraps=None, name=None, spec_set=None, parent=None,
|
||||
_spec_state=None, _new_name='', _new_parent=None, **kwargs):
|
||||
self.__dict__['_mock_return_value'] = return_value
|
||||
|
||||
_safe_super(CallableMixin, self).__init__(
|
||||
spec, wraps, name, spec_set, parent,
|
||||
_spec_state, _new_name, _new_parent, **kwargs
|
||||
|
@ -1081,7 +1113,6 @@ class Mock(CallableMixin, NonCallableMock):
|
|||
"""
|
||||
|
||||
|
||||
|
||||
def _dot_lookup(thing, comp, import_path):
|
||||
try:
|
||||
return getattr(thing, comp)
|
||||
|
@ -1279,8 +1310,10 @@ class _patch(object):
|
|||
if isinstance(original, type):
|
||||
# If we're patching out a class and there is a spec
|
||||
inherit = True
|
||||
|
||||
Klass = MagicMock
|
||||
if spec is None and _is_async_obj(original):
|
||||
Klass = AsyncMock
|
||||
else:
|
||||
Klass = MagicMock
|
||||
_kwargs = {}
|
||||
if new_callable is not None:
|
||||
Klass = new_callable
|
||||
|
@ -1292,7 +1325,9 @@ class _patch(object):
|
|||
not_callable = '__call__' not in this_spec
|
||||
else:
|
||||
not_callable = not callable(this_spec)
|
||||
if not_callable:
|
||||
if _is_async_obj(this_spec):
|
||||
Klass = AsyncMock
|
||||
elif not_callable:
|
||||
Klass = NonCallableMagicMock
|
||||
|
||||
if spec is not None:
|
||||
|
@ -1733,7 +1768,7 @@ _non_defaults = {
|
|||
'__reduce__', '__reduce_ex__', '__getinitargs__', '__getnewargs__',
|
||||
'__getstate__', '__setstate__', '__getformat__', '__setformat__',
|
||||
'__repr__', '__dir__', '__subclasses__', '__format__',
|
||||
'__getnewargs_ex__',
|
||||
'__getnewargs_ex__', '__aenter__', '__aexit__', '__anext__', '__aiter__',
|
||||
}
|
||||
|
||||
|
||||
|
@ -1750,6 +1785,11 @@ _magics = {
|
|||
' '.join([magic_methods, numerics, inplace, right]).split()
|
||||
}
|
||||
|
||||
# Magic methods used for async `with` statements
|
||||
_async_method_magics = {"__aenter__", "__aexit__", "__anext__"}
|
||||
# `__aiter__` is a plain function but used with async calls
|
||||
_async_magics = _async_method_magics | {"__aiter__"}
|
||||
|
||||
_all_magics = _magics | _non_defaults
|
||||
|
||||
_unsupported_magics = {
|
||||
|
@ -1779,6 +1819,7 @@ _return_values = {
|
|||
'__float__': 1.0,
|
||||
'__bool__': True,
|
||||
'__index__': 1,
|
||||
'__aexit__': False,
|
||||
}
|
||||
|
||||
|
||||
|
@ -1811,10 +1852,19 @@ def _get_iter(self):
|
|||
return iter(ret_val)
|
||||
return __iter__
|
||||
|
||||
def _get_async_iter(self):
|
||||
def __aiter__():
|
||||
ret_val = self.__aiter__._mock_return_value
|
||||
if ret_val is DEFAULT:
|
||||
return _AsyncIterator(iter([]))
|
||||
return _AsyncIterator(iter(ret_val))
|
||||
return __aiter__
|
||||
|
||||
_side_effect_methods = {
|
||||
'__eq__': _get_eq,
|
||||
'__ne__': _get_ne,
|
||||
'__iter__': _get_iter,
|
||||
'__aiter__': _get_async_iter
|
||||
}
|
||||
|
||||
|
||||
|
@ -1879,8 +1929,33 @@ class NonCallableMagicMock(MagicMixin, NonCallableMock):
|
|||
self._mock_set_magics()
|
||||
|
||||
|
||||
class AsyncMagicMixin:
|
||||
def __init__(self, *args, **kw):
|
||||
self._mock_set_async_magics() # make magic work for kwargs in init
|
||||
_safe_super(AsyncMagicMixin, self).__init__(*args, **kw)
|
||||
self._mock_set_async_magics() # fix magic broken by upper level init
|
||||
|
||||
class MagicMock(MagicMixin, Mock):
|
||||
def _mock_set_async_magics(self):
|
||||
these_magics = _async_magics
|
||||
|
||||
if getattr(self, "_mock_methods", None) is not None:
|
||||
these_magics = _async_magics.intersection(self._mock_methods)
|
||||
remove_magics = _async_magics - these_magics
|
||||
|
||||
for entry in remove_magics:
|
||||
if entry in type(self).__dict__:
|
||||
# remove unneeded magic methods
|
||||
delattr(self, entry)
|
||||
|
||||
# don't overwrite existing attributes if called a second time
|
||||
these_magics = these_magics - set(type(self).__dict__)
|
||||
|
||||
_type = type(self)
|
||||
for entry in these_magics:
|
||||
setattr(_type, entry, MagicProxy(entry, self))
|
||||
|
||||
|
||||
class MagicMock(MagicMixin, AsyncMagicMixin, Mock):
|
||||
"""
|
||||
MagicMock is a subclass of Mock with default implementations
|
||||
of most of the magic methods. You can use MagicMock without having to
|
||||
|
@ -1920,6 +1995,218 @@ class MagicProxy(object):
|
|||
return self.create_mock()
|
||||
|
||||
|
||||
class AsyncMockMixin(Base):
|
||||
awaited = _delegating_property('awaited')
|
||||
await_count = _delegating_property('await_count')
|
||||
await_args = _delegating_property('await_args')
|
||||
await_args_list = _delegating_property('await_args_list')
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# asyncio.iscoroutinefunction() checks _is_coroutine property to say if an
|
||||
# object is a coroutine. Without this check it looks to see if it is a
|
||||
# function/method, which in this case it is not (since it is an
|
||||
# AsyncMock).
|
||||
# It is set through __dict__ because when spec_set is True, this
|
||||
# attribute is likely undefined.
|
||||
self.__dict__['_is_coroutine'] = asyncio.coroutines._is_coroutine
|
||||
self.__dict__['_mock_awaited'] = _AwaitEvent(self)
|
||||
self.__dict__['_mock_await_count'] = 0
|
||||
self.__dict__['_mock_await_args'] = None
|
||||
self.__dict__['_mock_await_args_list'] = _CallList()
|
||||
code_mock = NonCallableMock(spec_set=CodeType)
|
||||
code_mock.co_flags = inspect.CO_COROUTINE
|
||||
self.__dict__['__code__'] = code_mock
|
||||
|
||||
async def _mock_call(_mock_self, *args, **kwargs):
|
||||
self = _mock_self
|
||||
try:
|
||||
result = super()._mock_call(*args, **kwargs)
|
||||
except (BaseException, StopIteration) as e:
|
||||
side_effect = self.side_effect
|
||||
if side_effect is not None and not callable(side_effect):
|
||||
raise
|
||||
return await _raise(e)
|
||||
|
||||
_call = self.call_args
|
||||
|
||||
async def proxy():
|
||||
try:
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
else:
|
||||
return result
|
||||
finally:
|
||||
self.await_count += 1
|
||||
self.await_args = _call
|
||||
self.await_args_list.append(_call)
|
||||
await self.awaited._notify()
|
||||
|
||||
return await proxy()
|
||||
|
||||
def assert_awaited(_mock_self):
|
||||
"""
|
||||
Assert that the mock was awaited at least once.
|
||||
"""
|
||||
self = _mock_self
|
||||
if self.await_count == 0:
|
||||
msg = f"Expected {self._mock_name or 'mock'} to have been awaited."
|
||||
raise AssertionError(msg)
|
||||
|
||||
def assert_awaited_once(_mock_self):
|
||||
"""
|
||||
Assert that the mock was awaited exactly once.
|
||||
"""
|
||||
self = _mock_self
|
||||
if not self.await_count == 1:
|
||||
msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once."
|
||||
f" Awaited {self.await_count} times.")
|
||||
raise AssertionError(msg)
|
||||
|
||||
def assert_awaited_with(_mock_self, *args, **kwargs):
|
||||
"""
|
||||
Assert that the last await was with the specified arguments.
|
||||
"""
|
||||
self = _mock_self
|
||||
if self.await_args is None:
|
||||
expected = self._format_mock_call_signature(args, kwargs)
|
||||
raise AssertionError(f'Expected await: {expected}\nNot awaited')
|
||||
|
||||
def _error_message():
|
||||
msg = self._format_mock_failure_message(args, kwargs)
|
||||
return msg
|
||||
|
||||
expected = self._call_matcher((args, kwargs))
|
||||
actual = self._call_matcher(self.await_args)
|
||||
if expected != actual:
|
||||
cause = expected if isinstance(expected, Exception) else None
|
||||
raise AssertionError(_error_message()) from cause
|
||||
|
||||
def assert_awaited_once_with(_mock_self, *args, **kwargs):
|
||||
"""
|
||||
Assert that the mock was awaited exactly once and with the specified
|
||||
arguments.
|
||||
"""
|
||||
self = _mock_self
|
||||
if not self.await_count == 1:
|
||||
msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once."
|
||||
f" Awaited {self.await_count} times.")
|
||||
raise AssertionError(msg)
|
||||
return self.assert_awaited_with(*args, **kwargs)
|
||||
|
||||
def assert_any_await(_mock_self, *args, **kwargs):
|
||||
"""
|
||||
Assert the mock has ever been awaited with the specified arguments.
|
||||
"""
|
||||
self = _mock_self
|
||||
expected = self._call_matcher((args, kwargs))
|
||||
actual = [self._call_matcher(c) for c in self.await_args_list]
|
||||
if expected not in actual:
|
||||
cause = expected if isinstance(expected, Exception) else None
|
||||
expected_string = self._format_mock_call_signature(args, kwargs)
|
||||
raise AssertionError(
|
||||
'%s await not found' % expected_string
|
||||
) from cause
|
||||
|
||||
def assert_has_awaits(_mock_self, calls, any_order=False):
|
||||
"""
|
||||
Assert the mock has been awaited with the specified calls.
|
||||
The :attr:`await_args_list` list is checked for the awaits.
|
||||
|
||||
If `any_order` is False (the default) then the awaits must be
|
||||
sequential. There can be extra calls before or after the
|
||||
specified awaits.
|
||||
|
||||
If `any_order` is True then the awaits can be in any order, but
|
||||
they must all appear in :attr:`await_args_list`.
|
||||
"""
|
||||
self = _mock_self
|
||||
expected = [self._call_matcher(c) for c in calls]
|
||||
cause = expected if isinstance(expected, Exception) else None
|
||||
all_awaits = _CallList(self._call_matcher(c) for c in self.await_args_list)
|
||||
if not any_order:
|
||||
if expected not in all_awaits:
|
||||
raise AssertionError(
|
||||
f'Awaits not found.\nExpected: {_CallList(calls)}\n',
|
||||
f'Actual: {self.await_args_list}'
|
||||
) from cause
|
||||
return
|
||||
|
||||
all_awaits = list(all_awaits)
|
||||
|
||||
not_found = []
|
||||
for kall in expected:
|
||||
try:
|
||||
all_awaits.remove(kall)
|
||||
except ValueError:
|
||||
not_found.append(kall)
|
||||
if not_found:
|
||||
raise AssertionError(
|
||||
'%r not all found in await list' % (tuple(not_found),)
|
||||
) from cause
|
||||
|
||||
def assert_not_awaited(_mock_self):
|
||||
"""
|
||||
Assert that the mock was never awaited.
|
||||
"""
|
||||
self = _mock_self
|
||||
if self.await_count != 0:
|
||||
msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once."
|
||||
f" Awaited {self.await_count} times.")
|
||||
raise AssertionError(msg)
|
||||
|
||||
def reset_mock(self, *args, **kwargs):
|
||||
"""
|
||||
See :func:`.Mock.reset_mock()`
|
||||
"""
|
||||
super().reset_mock(*args, **kwargs)
|
||||
self.await_count = 0
|
||||
self.await_args = None
|
||||
self.await_args_list = _CallList()
|
||||
|
||||
|
||||
class AsyncMock(AsyncMockMixin, AsyncMagicMixin, Mock):
|
||||
"""
|
||||
Enhance :class:`Mock` with features allowing to mock
|
||||
an async function.
|
||||
|
||||
The :class:`AsyncMock` object will behave so the object is
|
||||
recognized as an async function, and the result of a call is an awaitable:
|
||||
|
||||
>>> mock = AsyncMock()
|
||||
>>> asyncio.iscoroutinefunction(mock)
|
||||
True
|
||||
>>> inspect.isawaitable(mock())
|
||||
True
|
||||
|
||||
|
||||
The result of ``mock()`` is an async function which will have the outcome
|
||||
of ``side_effect`` or ``return_value``:
|
||||
|
||||
- if ``side_effect`` is a function, the async function will return the
|
||||
result of that function,
|
||||
- if ``side_effect`` is an exception, the async function will raise the
|
||||
exception,
|
||||
- if ``side_effect`` is an iterable, the async function will return the
|
||||
next value of the iterable, however, if the sequence of result is
|
||||
exhausted, ``StopIteration`` is raised immediately,
|
||||
- if ``side_effect`` is not defined, the async function will return the
|
||||
value defined by ``return_value``, hence, by default, the async function
|
||||
returns a new :class:`AsyncMock` object.
|
||||
|
||||
If the outcome of ``side_effect`` or ``return_value`` is an async function,
|
||||
the mock async function obtained when the mock object is called will be this
|
||||
async function itself (and not an async function returning an async
|
||||
function).
|
||||
|
||||
The test author can also specify a wrapped object with ``wraps``. In this
|
||||
case, the :class:`Mock` object behavior is the same as with an
|
||||
:class:`.Mock` object: the wrapped object may have methods
|
||||
defined as async function functions.
|
||||
|
||||
Based on Martin Richard's asyntest project.
|
||||
"""
|
||||
|
||||
|
||||
class _ANY(object):
|
||||
"A helper object that compares equal to everything."
|
||||
|
@ -2145,7 +2432,6 @@ class _Call(tuple):
|
|||
call = _Call(from_kall=False)
|
||||
|
||||
|
||||
|
||||
def create_autospec(spec, spec_set=False, instance=False, _parent=None,
|
||||
_name=None, **kwargs):
|
||||
"""Create a mock object using another object as a spec. Attributes on the
|
||||
|
@ -2171,7 +2457,10 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
|
|||
spec = type(spec)
|
||||
|
||||
is_type = isinstance(spec, type)
|
||||
|
||||
if getattr(spec, '__code__', None):
|
||||
is_async_func = asyncio.iscoroutinefunction(spec)
|
||||
else:
|
||||
is_async_func = False
|
||||
_kwargs = {'spec': spec}
|
||||
if spec_set:
|
||||
_kwargs = {'spec_set': spec}
|
||||
|
@ -2188,6 +2477,11 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
|
|||
# descriptors don't have a spec
|
||||
# because we don't know what type they return
|
||||
_kwargs = {}
|
||||
elif is_async_func:
|
||||
if instance:
|
||||
raise RuntimeError("Instance can not be True when create_autospec "
|
||||
"is mocking an async function")
|
||||
Klass = AsyncMock
|
||||
elif not _callable(spec):
|
||||
Klass = NonCallableMagicMock
|
||||
elif is_type and instance and not _instance_callable(spec):
|
||||
|
@ -2204,9 +2498,26 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
|
|||
name=_name, **_kwargs)
|
||||
|
||||
if isinstance(spec, FunctionTypes):
|
||||
wrapped_mock = mock
|
||||
# should only happen at the top level because we don't
|
||||
# recurse for functions
|
||||
mock = _set_signature(mock, spec)
|
||||
if is_async_func:
|
||||
mock._is_coroutine = asyncio.coroutines._is_coroutine
|
||||
mock.await_count = 0
|
||||
mock.await_args = None
|
||||
mock.await_args_list = _CallList()
|
||||
|
||||
for a in ('assert_awaited',
|
||||
'assert_awaited_once',
|
||||
'assert_awaited_with',
|
||||
'assert_awaited_once_with',
|
||||
'assert_any_await',
|
||||
'assert_has_awaits',
|
||||
'assert_not_awaited'):
|
||||
def f(*args, **kwargs):
|
||||
return getattr(wrapped_mock, a)(*args, **kwargs)
|
||||
setattr(mock, a, f)
|
||||
else:
|
||||
_check_signature(spec, mock, is_type, instance)
|
||||
|
||||
|
@ -2250,9 +2561,13 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
|
|||
|
||||
skipfirst = _must_skip(spec, entry, is_type)
|
||||
kwargs['_eat_self'] = skipfirst
|
||||
new = MagicMock(parent=parent, name=entry, _new_name=entry,
|
||||
_new_parent=parent,
|
||||
**kwargs)
|
||||
if asyncio.iscoroutinefunction(original):
|
||||
child_klass = AsyncMock
|
||||
else:
|
||||
child_klass = MagicMock
|
||||
new = child_klass(parent=parent, name=entry, _new_name=entry,
|
||||
_new_parent=parent,
|
||||
**kwargs)
|
||||
mock._mock_children[entry] = new
|
||||
_check_signature(original, new, skipfirst=skipfirst)
|
||||
|
||||
|
@ -2438,3 +2753,60 @@ def seal(mock):
|
|||
continue
|
||||
if m._mock_new_parent is mock:
|
||||
seal(m)
|
||||
|
||||
|
||||
async def _raise(exception):
|
||||
raise exception
|
||||
|
||||
|
||||
class _AsyncIterator:
|
||||
"""
|
||||
Wraps an iterator in an asynchronous iterator.
|
||||
"""
|
||||
def __init__(self, iterator):
|
||||
self.iterator = iterator
|
||||
code_mock = NonCallableMock(spec_set=CodeType)
|
||||
code_mock.co_flags = inspect.CO_ITERABLE_COROUTINE
|
||||
self.__dict__['__code__'] = code_mock
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
try:
|
||||
return next(self.iterator)
|
||||
except StopIteration:
|
||||
pass
|
||||
raise StopAsyncIteration
|
||||
|
||||
|
||||
class _AwaitEvent:
|
||||
def __init__(self, mock):
|
||||
self._mock = mock
|
||||
self._condition = None
|
||||
|
||||
async def _notify(self):
|
||||
condition = self._get_condition()
|
||||
try:
|
||||
await condition.acquire()
|
||||
condition.notify_all()
|
||||
finally:
|
||||
condition.release()
|
||||
|
||||
def _get_condition(self):
|
||||
"""
|
||||
Creation of condition is delayed, to minimize the chance of using the
|
||||
wrong loop.
|
||||
A user may create a mock with _AwaitEvent before selecting the
|
||||
execution loop. Requiring a user to delay creation is error-prone and
|
||||
inflexible. Instead, condition is created when user actually starts to
|
||||
use the mock.
|
||||
"""
|
||||
# No synchronization is needed:
|
||||
# - asyncio is thread unsafe
|
||||
# - there are no awaits here, method will be executed without
|
||||
# switching asyncio context.
|
||||
if self._condition is None:
|
||||
self._condition = asyncio.Condition()
|
||||
|
||||
return self._condition
|
||||
|
|
|
@ -0,0 +1,549 @@
|
|||
import asyncio
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
from unittest.mock import call, AsyncMock, patch, MagicMock, create_autospec
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
|
||||
class AsyncClass:
|
||||
def __init__(self):
|
||||
pass
|
||||
async def async_method(self):
|
||||
pass
|
||||
def normal_method(self):
|
||||
pass
|
||||
|
||||
async def async_func():
|
||||
pass
|
||||
|
||||
def normal_func():
|
||||
pass
|
||||
|
||||
class NormalClass(object):
|
||||
def a(self):
|
||||
pass
|
||||
|
||||
|
||||
async_foo_name = f'{__name__}.AsyncClass'
|
||||
normal_foo_name = f'{__name__}.NormalClass'
|
||||
|
||||
|
||||
class AsyncPatchDecoratorTest(unittest.TestCase):
|
||||
def test_is_coroutine_function_patch(self):
|
||||
@patch.object(AsyncClass, 'async_method')
|
||||
def test_async(mock_method):
|
||||
self.assertTrue(asyncio.iscoroutinefunction(mock_method))
|
||||
test_async()
|
||||
|
||||
def test_is_async_patch(self):
|
||||
@patch.object(AsyncClass, 'async_method')
|
||||
def test_async(mock_method):
|
||||
m = mock_method()
|
||||
self.assertTrue(inspect.isawaitable(m))
|
||||
asyncio.run(m)
|
||||
|
||||
@patch(f'{async_foo_name}.async_method')
|
||||
def test_no_parent_attribute(mock_method):
|
||||
m = mock_method()
|
||||
self.assertTrue(inspect.isawaitable(m))
|
||||
asyncio.run(m)
|
||||
|
||||
test_async()
|
||||
test_no_parent_attribute()
|
||||
|
||||
def test_is_AsyncMock_patch(self):
|
||||
@patch.object(AsyncClass, 'async_method')
|
||||
def test_async(mock_method):
|
||||
self.assertIsInstance(mock_method, AsyncMock)
|
||||
|
||||
test_async()
|
||||
|
||||
|
||||
class AsyncPatchCMTest(unittest.TestCase):
|
||||
def test_is_async_function_cm(self):
|
||||
def test_async():
|
||||
with patch.object(AsyncClass, 'async_method') as mock_method:
|
||||
self.assertTrue(asyncio.iscoroutinefunction(mock_method))
|
||||
|
||||
test_async()
|
||||
|
||||
def test_is_async_cm(self):
|
||||
def test_async():
|
||||
with patch.object(AsyncClass, 'async_method') as mock_method:
|
||||
m = mock_method()
|
||||
self.assertTrue(inspect.isawaitable(m))
|
||||
asyncio.run(m)
|
||||
|
||||
test_async()
|
||||
|
||||
def test_is_AsyncMock_cm(self):
|
||||
def test_async():
|
||||
with patch.object(AsyncClass, 'async_method') as mock_method:
|
||||
self.assertIsInstance(mock_method, AsyncMock)
|
||||
|
||||
test_async()
|
||||
|
||||
|
||||
class AsyncMockTest(unittest.TestCase):
|
||||
def test_iscoroutinefunction_default(self):
|
||||
mock = AsyncMock()
|
||||
self.assertTrue(asyncio.iscoroutinefunction(mock))
|
||||
|
||||
def test_iscoroutinefunction_function(self):
|
||||
async def foo(): pass
|
||||
mock = AsyncMock(foo)
|
||||
self.assertTrue(asyncio.iscoroutinefunction(mock))
|
||||
self.assertTrue(inspect.iscoroutinefunction(mock))
|
||||
|
||||
def test_isawaitable(self):
|
||||
mock = AsyncMock()
|
||||
m = mock()
|
||||
self.assertTrue(inspect.isawaitable(m))
|
||||
asyncio.run(m)
|
||||
self.assertIn('assert_awaited', dir(mock))
|
||||
|
||||
def test_iscoroutinefunction_normal_function(self):
|
||||
def foo(): pass
|
||||
mock = AsyncMock(foo)
|
||||
self.assertTrue(asyncio.iscoroutinefunction(mock))
|
||||
self.assertTrue(inspect.iscoroutinefunction(mock))
|
||||
|
||||
def test_future_isfuture(self):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
fut = asyncio.Future()
|
||||
loop.stop()
|
||||
loop.close()
|
||||
mock = AsyncMock(fut)
|
||||
self.assertIsInstance(mock, asyncio.Future)
|
||||
|
||||
|
||||
class AsyncAutospecTest(unittest.TestCase):
|
||||
def test_is_AsyncMock_patch(self):
|
||||
@patch(async_foo_name, autospec=True)
|
||||
def test_async(mock_method):
|
||||
self.assertIsInstance(mock_method.async_method, AsyncMock)
|
||||
self.assertIsInstance(mock_method, MagicMock)
|
||||
|
||||
@patch(async_foo_name, autospec=True)
|
||||
def test_normal_method(mock_method):
|
||||
self.assertIsInstance(mock_method.normal_method, MagicMock)
|
||||
|
||||
test_async()
|
||||
test_normal_method()
|
||||
|
||||
def test_create_autospec_instance(self):
|
||||
with self.assertRaises(RuntimeError):
|
||||
create_autospec(async_func, instance=True)
|
||||
|
||||
def test_create_autospec(self):
|
||||
spec = create_autospec(async_func)
|
||||
self.assertTrue(asyncio.iscoroutinefunction(spec))
|
||||
|
||||
|
||||
class AsyncSpecTest(unittest.TestCase):
|
||||
def test_spec_as_async_positional_magicmock(self):
|
||||
mock = MagicMock(async_func)
|
||||
self.assertIsInstance(mock, MagicMock)
|
||||
m = mock()
|
||||
self.assertTrue(inspect.isawaitable(m))
|
||||
asyncio.run(m)
|
||||
|
||||
def test_spec_as_async_kw_magicmock(self):
|
||||
mock = MagicMock(spec=async_func)
|
||||
self.assertIsInstance(mock, MagicMock)
|
||||
m = mock()
|
||||
self.assertTrue(inspect.isawaitable(m))
|
||||
asyncio.run(m)
|
||||
|
||||
def test_spec_as_async_kw_AsyncMock(self):
|
||||
mock = AsyncMock(spec=async_func)
|
||||
self.assertIsInstance(mock, AsyncMock)
|
||||
m = mock()
|
||||
self.assertTrue(inspect.isawaitable(m))
|
||||
asyncio.run(m)
|
||||
|
||||
def test_spec_as_async_positional_AsyncMock(self):
|
||||
mock = AsyncMock(async_func)
|
||||
self.assertIsInstance(mock, AsyncMock)
|
||||
m = mock()
|
||||
self.assertTrue(inspect.isawaitable(m))
|
||||
asyncio.run(m)
|
||||
|
||||
def test_spec_as_normal_kw_AsyncMock(self):
|
||||
mock = AsyncMock(spec=normal_func)
|
||||
self.assertIsInstance(mock, AsyncMock)
|
||||
m = mock()
|
||||
self.assertTrue(inspect.isawaitable(m))
|
||||
asyncio.run(m)
|
||||
|
||||
def test_spec_as_normal_positional_AsyncMock(self):
|
||||
mock = AsyncMock(normal_func)
|
||||
self.assertIsInstance(mock, AsyncMock)
|
||||
m = mock()
|
||||
self.assertTrue(inspect.isawaitable(m))
|
||||
asyncio.run(m)
|
||||
|
||||
def test_spec_async_mock(self):
|
||||
@patch.object(AsyncClass, 'async_method', spec=True)
|
||||
def test_async(mock_method):
|
||||
self.assertIsInstance(mock_method, AsyncMock)
|
||||
|
||||
test_async()
|
||||
|
||||
def test_spec_parent_not_async_attribute_is(self):
|
||||
@patch(async_foo_name, spec=True)
|
||||
def test_async(mock_method):
|
||||
self.assertIsInstance(mock_method, MagicMock)
|
||||
self.assertIsInstance(mock_method.async_method, AsyncMock)
|
||||
|
||||
test_async()
|
||||
|
||||
def test_target_async_spec_not(self):
|
||||
@patch.object(AsyncClass, 'async_method', spec=NormalClass.a)
|
||||
def test_async_attribute(mock_method):
|
||||
self.assertIsInstance(mock_method, MagicMock)
|
||||
self.assertFalse(inspect.iscoroutine(mock_method))
|
||||
self.assertFalse(inspect.isawaitable(mock_method))
|
||||
|
||||
test_async_attribute()
|
||||
|
||||
def test_target_not_async_spec_is(self):
|
||||
@patch.object(NormalClass, 'a', spec=async_func)
|
||||
def test_attribute_not_async_spec_is(mock_async_func):
|
||||
self.assertIsInstance(mock_async_func, AsyncMock)
|
||||
test_attribute_not_async_spec_is()
|
||||
|
||||
def test_spec_async_attributes(self):
|
||||
@patch(normal_foo_name, spec=AsyncClass)
|
||||
def test_async_attributes_coroutines(MockNormalClass):
|
||||
self.assertIsInstance(MockNormalClass.async_method, AsyncMock)
|
||||
self.assertIsInstance(MockNormalClass, MagicMock)
|
||||
|
||||
test_async_attributes_coroutines()
|
||||
|
||||
|
||||
class AsyncSpecSetTest(unittest.TestCase):
|
||||
def test_is_AsyncMock_patch(self):
|
||||
@patch.object(AsyncClass, 'async_method', spec_set=True)
|
||||
def test_async(async_method):
|
||||
self.assertIsInstance(async_method, AsyncMock)
|
||||
|
||||
def test_is_async_AsyncMock(self):
|
||||
mock = AsyncMock(spec_set=AsyncClass.async_method)
|
||||
self.assertTrue(asyncio.iscoroutinefunction(mock))
|
||||
self.assertIsInstance(mock, AsyncMock)
|
||||
|
||||
def test_is_child_AsyncMock(self):
|
||||
mock = MagicMock(spec_set=AsyncClass)
|
||||
self.assertTrue(asyncio.iscoroutinefunction(mock.async_method))
|
||||
self.assertFalse(asyncio.iscoroutinefunction(mock.normal_method))
|
||||
self.assertIsInstance(mock.async_method, AsyncMock)
|
||||
self.assertIsInstance(mock.normal_method, MagicMock)
|
||||
self.assertIsInstance(mock, MagicMock)
|
||||
|
||||
|
||||
class AsyncArguments(unittest.TestCase):
|
||||
def test_add_return_value(self):
|
||||
async def addition(self, var):
|
||||
return var + 1
|
||||
|
||||
mock = AsyncMock(addition, return_value=10)
|
||||
output = asyncio.run(mock(5))
|
||||
|
||||
self.assertEqual(output, 10)
|
||||
|
||||
def test_add_side_effect_exception(self):
|
||||
async def addition(var):
|
||||
return var + 1
|
||||
mock = AsyncMock(addition, side_effect=Exception('err'))
|
||||
with self.assertRaises(Exception):
|
||||
asyncio.run(mock(5))
|
||||
|
||||
def test_add_side_effect_function(self):
|
||||
async def addition(var):
|
||||
return var + 1
|
||||
mock = AsyncMock(side_effect=addition)
|
||||
result = asyncio.run(mock(5))
|
||||
self.assertEqual(result, 6)
|
||||
|
||||
def test_add_side_effect_iterable(self):
|
||||
vals = [1, 2, 3]
|
||||
mock = AsyncMock(side_effect=vals)
|
||||
for item in vals:
|
||||
self.assertEqual(item, asyncio.run(mock()))
|
||||
|
||||
with self.assertRaises(RuntimeError) as e:
|
||||
asyncio.run(mock())
|
||||
self.assertEqual(
|
||||
e.exception,
|
||||
RuntimeError('coroutine raised StopIteration')
|
||||
)
|
||||
|
||||
|
||||
class AsyncContextManagerTest(unittest.TestCase):
|
||||
class WithAsyncContextManager:
|
||||
def __init__(self):
|
||||
self.entered = False
|
||||
self.exited = False
|
||||
|
||||
async def __aenter__(self, *args, **kwargs):
|
||||
self.entered = True
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args, **kwargs):
|
||||
self.exited = True
|
||||
|
||||
def test_magic_methods_are_async_mocks(self):
|
||||
mock = MagicMock(self.WithAsyncContextManager())
|
||||
self.assertIsInstance(mock.__aenter__, AsyncMock)
|
||||
self.assertIsInstance(mock.__aexit__, AsyncMock)
|
||||
|
||||
def test_mock_supports_async_context_manager(self):
|
||||
called = False
|
||||
instance = self.WithAsyncContextManager()
|
||||
mock_instance = MagicMock(instance)
|
||||
|
||||
async def use_context_manager():
|
||||
nonlocal called
|
||||
async with mock_instance as result:
|
||||
called = True
|
||||
return result
|
||||
|
||||
result = asyncio.run(use_context_manager())
|
||||
self.assertFalse(instance.entered)
|
||||
self.assertFalse(instance.exited)
|
||||
self.assertTrue(called)
|
||||
self.assertTrue(mock_instance.entered)
|
||||
self.assertTrue(mock_instance.exited)
|
||||
self.assertTrue(mock_instance.__aenter__.called)
|
||||
self.assertTrue(mock_instance.__aexit__.called)
|
||||
self.assertIsNot(mock_instance, result)
|
||||
self.assertIsInstance(result, AsyncMock)
|
||||
|
||||
def test_mock_customize_async_context_manager(self):
|
||||
instance = self.WithAsyncContextManager()
|
||||
mock_instance = MagicMock(instance)
|
||||
|
||||
expected_result = object()
|
||||
mock_instance.__aenter__.return_value = expected_result
|
||||
|
||||
async def use_context_manager():
|
||||
async with mock_instance as result:
|
||||
return result
|
||||
|
||||
self.assertIs(asyncio.run(use_context_manager()), expected_result)
|
||||
|
||||
def test_mock_customize_async_context_manager_with_coroutine(self):
|
||||
enter_called = False
|
||||
exit_called = False
|
||||
|
||||
async def enter_coroutine(*args):
|
||||
nonlocal enter_called
|
||||
enter_called = True
|
||||
|
||||
async def exit_coroutine(*args):
|
||||
nonlocal exit_called
|
||||
exit_called = True
|
||||
|
||||
instance = self.WithAsyncContextManager()
|
||||
mock_instance = MagicMock(instance)
|
||||
|
||||
mock_instance.__aenter__ = enter_coroutine
|
||||
mock_instance.__aexit__ = exit_coroutine
|
||||
|
||||
async def use_context_manager():
|
||||
async with mock_instance:
|
||||
pass
|
||||
|
||||
asyncio.run(use_context_manager())
|
||||
self.assertTrue(enter_called)
|
||||
self.assertTrue(exit_called)
|
||||
|
||||
def test_context_manager_raise_exception_by_default(self):
|
||||
async def raise_in(context_manager):
|
||||
async with context_manager:
|
||||
raise TypeError()
|
||||
|
||||
instance = self.WithAsyncContextManager()
|
||||
mock_instance = MagicMock(instance)
|
||||
with self.assertRaises(TypeError):
|
||||
asyncio.run(raise_in(mock_instance))
|
||||
|
||||
|
||||
class AsyncIteratorTest(unittest.TestCase):
|
||||
class WithAsyncIterator(object):
|
||||
def __init__(self):
|
||||
self.items = ["foo", "NormalFoo", "baz"]
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
try:
|
||||
return self.items.pop()
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
raise StopAsyncIteration
|
||||
|
||||
def test_mock_aiter_and_anext(self):
|
||||
instance = self.WithAsyncIterator()
|
||||
mock_instance = MagicMock(instance)
|
||||
|
||||
self.assertEqual(asyncio.iscoroutine(instance.__aiter__),
|
||||
asyncio.iscoroutine(mock_instance.__aiter__))
|
||||
self.assertEqual(asyncio.iscoroutine(instance.__anext__),
|
||||
asyncio.iscoroutine(mock_instance.__anext__))
|
||||
|
||||
iterator = instance.__aiter__()
|
||||
if asyncio.iscoroutine(iterator):
|
||||
iterator = asyncio.run(iterator)
|
||||
|
||||
mock_iterator = mock_instance.__aiter__()
|
||||
if asyncio.iscoroutine(mock_iterator):
|
||||
mock_iterator = asyncio.run(mock_iterator)
|
||||
|
||||
self.assertEqual(asyncio.iscoroutine(iterator.__aiter__),
|
||||
asyncio.iscoroutine(mock_iterator.__aiter__))
|
||||
self.assertEqual(asyncio.iscoroutine(iterator.__anext__),
|
||||
asyncio.iscoroutine(mock_iterator.__anext__))
|
||||
|
||||
def test_mock_async_for(self):
|
||||
async def iterate(iterator):
|
||||
accumulator = []
|
||||
async for item in iterator:
|
||||
accumulator.append(item)
|
||||
|
||||
return accumulator
|
||||
|
||||
expected = ["FOO", "BAR", "BAZ"]
|
||||
with self.subTest("iterate through default value"):
|
||||
mock_instance = MagicMock(self.WithAsyncIterator())
|
||||
self.assertEqual([], asyncio.run(iterate(mock_instance)))
|
||||
|
||||
with self.subTest("iterate through set return_value"):
|
||||
mock_instance = MagicMock(self.WithAsyncIterator())
|
||||
mock_instance.__aiter__.return_value = expected[:]
|
||||
self.assertEqual(expected, asyncio.run(iterate(mock_instance)))
|
||||
|
||||
with self.subTest("iterate through set return_value iterator"):
|
||||
mock_instance = MagicMock(self.WithAsyncIterator())
|
||||
mock_instance.__aiter__.return_value = iter(expected[:])
|
||||
self.assertEqual(expected, asyncio.run(iterate(mock_instance)))
|
||||
|
||||
|
||||
class AsyncMockAssert(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.mock = AsyncMock()
|
||||
|
||||
async def _runnable_test(self, *args):
|
||||
if not args:
|
||||
await self.mock()
|
||||
else:
|
||||
await self.mock(*args)
|
||||
|
||||
def test_assert_awaited(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
self.mock.assert_awaited()
|
||||
|
||||
asyncio.run(self._runnable_test())
|
||||
self.mock.assert_awaited()
|
||||
|
||||
def test_assert_awaited_once(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
self.mock.assert_awaited_once()
|
||||
|
||||
asyncio.run(self._runnable_test())
|
||||
self.mock.assert_awaited_once()
|
||||
|
||||
asyncio.run(self._runnable_test())
|
||||
with self.assertRaises(AssertionError):
|
||||
self.mock.assert_awaited_once()
|
||||
|
||||
def test_assert_awaited_with(self):
|
||||
asyncio.run(self._runnable_test())
|
||||
with self.assertRaises(AssertionError):
|
||||
self.mock.assert_awaited_with('foo')
|
||||
|
||||
asyncio.run(self._runnable_test('foo'))
|
||||
self.mock.assert_awaited_with('foo')
|
||||
|
||||
asyncio.run(self._runnable_test('SomethingElse'))
|
||||
with self.assertRaises(AssertionError):
|
||||
self.mock.assert_awaited_with('foo')
|
||||
|
||||
def test_assert_awaited_once_with(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
self.mock.assert_awaited_once_with('foo')
|
||||
|
||||
asyncio.run(self._runnable_test('foo'))
|
||||
self.mock.assert_awaited_once_with('foo')
|
||||
|
||||
asyncio.run(self._runnable_test('foo'))
|
||||
with self.assertRaises(AssertionError):
|
||||
self.mock.assert_awaited_once_with('foo')
|
||||
|
||||
def test_assert_any_wait(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
self.mock.assert_any_await('NormalFoo')
|
||||
|
||||
asyncio.run(self._runnable_test('foo'))
|
||||
with self.assertRaises(AssertionError):
|
||||
self.mock.assert_any_await('NormalFoo')
|
||||
|
||||
asyncio.run(self._runnable_test('NormalFoo'))
|
||||
self.mock.assert_any_await('NormalFoo')
|
||||
|
||||
asyncio.run(self._runnable_test('SomethingElse'))
|
||||
self.mock.assert_any_await('NormalFoo')
|
||||
|
||||
def test_assert_has_awaits_no_order(self):
|
||||
calls = [call('NormalFoo'), call('baz')]
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
self.mock.assert_has_awaits(calls)
|
||||
|
||||
asyncio.run(self._runnable_test('foo'))
|
||||
with self.assertRaises(AssertionError):
|
||||
self.mock.assert_has_awaits(calls)
|
||||
|
||||
asyncio.run(self._runnable_test('NormalFoo'))
|
||||
with self.assertRaises(AssertionError):
|
||||
self.mock.assert_has_awaits(calls)
|
||||
|
||||
asyncio.run(self._runnable_test('baz'))
|
||||
self.mock.assert_has_awaits(calls)
|
||||
|
||||
asyncio.run(self._runnable_test('SomethingElse'))
|
||||
self.mock.assert_has_awaits(calls)
|
||||
|
||||
def test_assert_has_awaits_ordered(self):
|
||||
calls = [call('NormalFoo'), call('baz')]
|
||||
with self.assertRaises(AssertionError):
|
||||
self.mock.assert_has_awaits(calls, any_order=True)
|
||||
|
||||
asyncio.run(self._runnable_test('baz'))
|
||||
with self.assertRaises(AssertionError):
|
||||
self.mock.assert_has_awaits(calls, any_order=True)
|
||||
|
||||
asyncio.run(self._runnable_test('foo'))
|
||||
with self.assertRaises(AssertionError):
|
||||
self.mock.assert_has_awaits(calls, any_order=True)
|
||||
|
||||
asyncio.run(self._runnable_test('NormalFoo'))
|
||||
self.mock.assert_has_awaits(calls, any_order=True)
|
||||
|
||||
asyncio.run(self._runnable_test('qux'))
|
||||
self.mock.assert_has_awaits(calls, any_order=True)
|
||||
|
||||
def test_assert_not_awaited(self):
|
||||
self.mock.assert_not_awaited()
|
||||
|
||||
asyncio.run(self._runnable_test())
|
||||
with self.assertRaises(AssertionError):
|
||||
self.mock.assert_not_awaited()
|
|
@ -9,7 +9,7 @@ from unittest import mock
|
|||
from unittest.mock import (
|
||||
call, DEFAULT, patch, sentinel,
|
||||
MagicMock, Mock, NonCallableMock,
|
||||
NonCallableMagicMock, _Call, _CallList,
|
||||
NonCallableMagicMock, AsyncMock, _Call, _CallList,
|
||||
create_autospec
|
||||
)
|
||||
|
||||
|
@ -1618,7 +1618,8 @@ class MockTest(unittest.TestCase):
|
|||
|
||||
|
||||
def test_adding_child_mock(self):
|
||||
for Klass in NonCallableMock, Mock, MagicMock, NonCallableMagicMock:
|
||||
for Klass in (NonCallableMock, Mock, MagicMock, NonCallableMagicMock,
|
||||
AsyncMock):
|
||||
mock = Klass()
|
||||
|
||||
mock.foo = Mock()
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
Added AsyncMock to support using unittest to mock asyncio coroutines.
|
||||
Patch by Lisa Roach.
|
Loading…
Reference in New Issue