bpo-26467: Adds AsyncMock for asyncio Mock library support (GH-9296)

This commit is contained in:
Lisa Roach 2019-05-20 09:19:53 -07:00 committed by GitHub
parent 0f72147ce2
commit 77b3b7701a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 1161 additions and 20 deletions

View File

@ -201,9 +201,11 @@ The Mock Class
.. testsetup::
import asyncio
import inspect
import unittest
from unittest.mock import sentinel, DEFAULT, ANY
from unittest.mock import patch, call, Mock, MagicMock, PropertyMock
from unittest.mock import patch, call, Mock, MagicMock, PropertyMock, AsyncMock
from unittest.mock import mock_open
:class:`Mock` is a flexible mock object intended to replace the use of stubs and
@ -851,6 +853,217 @@ object::
>>> p.assert_called_once_with()
.. class:: AsyncMock(spec=None, side_effect=None, return_value=DEFAULT, wraps=None, name=None, spec_set=None, unsafe=False, **kwargs)
An asynchronous version of :class:`Mock`. The :class:`AsyncMock` object will
behave so the object is recognized as an async function, and the result of a
call is an awaitable.
>>> mock = AsyncMock()
>>> asyncio.iscoroutinefunction(mock)
True
>>> inspect.isawaitable(mock())
True
The result of ``mock()`` is an async function which will have the outcome
of ``side_effect`` or ``return_value``:
- if ``side_effect`` is a function, the async function will return the
result of that function,
- if ``side_effect`` is an exception, the async function will raise the
exception,
- if ``side_effect`` is an iterable, the async function will return the
next value of the iterable, however, if the sequence of result is
exhausted, ``StopIteration`` is raised immediately,
- if ``side_effect`` is not defined, the async function will return the
value defined by ``return_value``, hence, by default, the async function
returns a new :class:`AsyncMock` object.
Setting the *spec* of a :class:`Mock` or :class:`MagicMock` to an async function
will result in a coroutine object being returned after calling.
>>> async def async_func(): pass
...
>>> mock = MagicMock(async_func)
>>> mock
<MagicMock spec='function' id='...'>
>>> mock()
<coroutine object AsyncMockMixin._mock_call at ...>
.. method:: assert_awaited()
Assert that the mock was awaited at least once.
>>> mock = AsyncMock()
>>> async def main():
... await mock()
...
>>> asyncio.run(main())
>>> mock.assert_awaited()
>>> mock_2 = AsyncMock()
>>> mock_2.assert_awaited()
Traceback (most recent call last):
...
AssertionError: Expected mock to have been awaited.
.. method:: assert_awaited_once()
Assert that the mock was awaited exactly once.
>>> mock = AsyncMock()
>>> async def main():
... await mock()
...
>>> asyncio.run(main())
>>> mock.assert_awaited_once()
>>> asyncio.run(main())
>>> mock.method.assert_awaited_once()
Traceback (most recent call last):
...
AssertionError: Expected mock to have been awaited once. Awaited 2 times.
.. method:: assert_awaited_with(*args, **kwargs)
Assert that the last await was with the specified arguments.
>>> mock = AsyncMock()
>>> async def main(*args, **kwargs):
... await mock(*args, **kwargs)
...
>>> asyncio.run(main('foo', bar='bar'))
>>> mock.assert_awaited_with('foo', bar='bar')
>>> mock.assert_awaited_with('other')
Traceback (most recent call last):
...
AssertionError: expected call not found.
Expected: mock('other')
Actual: mock('foo', bar='bar')
.. method:: assert_awaited_once_with(*args, **kwargs)
Assert that the mock was awaited exactly once and with the specified
arguments.
>>> mock = AsyncMock()
>>> async def main(*args, **kwargs):
... await mock(*args, **kwargs)
...
>>> asyncio.run(main('foo', bar='bar'))
>>> mock.assert_awaited_once_with('foo', bar='bar')
>>> asyncio.run(main('foo', bar='bar'))
>>> mock.assert_awaited_once_with('foo', bar='bar')
Traceback (most recent call last):
...
AssertionError: Expected mock to have been awaited once. Awaited 2 times.
.. method:: assert_any_await(*args, **kwargs)
Assert the mock has ever been awaited with the specified arguments.
>>> mock = AsyncMock()
>>> async def main(*args, **kwargs):
... await mock(*args, **kwargs)
...
>>> asyncio.run(main('foo', bar='bar'))
>>> asyncio.run(main('hello'))
>>> mock.assert_any_await('foo', bar='bar')
>>> mock.assert_any_await('other')
Traceback (most recent call last):
...
AssertionError: mock('other') await not found
.. method:: assert_has_awaits(calls, any_order=False)
Assert the mock has been awaited with the specified calls.
The :attr:`await_args_list` list is checked for the awaits.
If *any_order* is False (the default) then the awaits must be
sequential. There can be extra calls before or after the
specified awaits.
If *any_order* is True then the awaits can be in any order, but
they must all appear in :attr:`await_args_list`.
>>> mock = AsyncMock()
>>> async def main(*args, **kwargs):
... await mock(*args, **kwargs)
...
>>> calls = [call("foo"), call("bar")]
>>> mock.assert_has_calls(calls)
Traceback (most recent call last):
...
AssertionError: Calls not found.
Expected: [call('foo'), call('bar')]
>>> asyncio.run(main('foo'))
>>> asyncio.run(main('bar'))
>>> mock.assert_has_calls(calls)
.. method:: assert_not_awaited()
Assert that the mock was never awaited.
>>> mock = AsyncMock()
>>> mock.assert_not_awaited()
.. method:: reset_mock(*args, **kwargs)
See :func:`Mock.reset_mock`. Also sets :attr:`await_count` to 0,
:attr:`await_args` to None, and clears the :attr:`await_args_list`.
.. attribute:: await_count
An integer keeping track of how many times the mock object has been awaited.
>>> mock = AsyncMock()
>>> async def main():
... await mock()
...
>>> asyncio.run(main())
>>> mock.await_count
1
>>> asyncio.run(main())
>>> mock.await_count
2
.. attribute:: await_args
This is either ``None`` (if the mock hasnt been awaited), or the arguments that
the mock was last awaited with. Functions the same as :attr:`Mock.call_args`.
>>> mock = AsyncMock()
>>> async def main(*args):
... await mock(*args)
...
>>> mock.await_args
>>> asyncio.run(main('foo'))
>>> mock.await_args
call('foo')
>>> asyncio.run(main('bar'))
>>> mock.await_args
call('bar')
.. attribute:: await_args_list
This is a list of all the awaits made to the mock object in sequence (so the
length of the list is the number of times it has been awaited). Before any
awaits have been made it is an empty list.
>>> mock = AsyncMock()
>>> async def main(*args):
... await mock(*args)
...
>>> mock.await_args_list
[]
>>> asyncio.run(main('foo'))
>>> mock.await_args_list
[call('foo')]
>>> asyncio.run(main('bar'))
>>> mock.await_args_list
[call('foo'), call('bar')]
Calling
~~~~~~~

View File

@ -538,6 +538,10 @@ unicodedata
unittest
--------
* XXX Added :class:`AsyncMock` to support an asynchronous version of :class:`Mock`.
Appropriate new assert functions for testing have been added as well.
(Contributed by Lisa Roach in :issue:`26467`).
* Added :func:`~unittest.addModuleCleanup()` and
:meth:`~unittest.TestCase.addClassCleanup()` to unittest to support
cleanups for :func:`~unittest.setUpModule()` and

View File

@ -13,6 +13,7 @@ __all__ = (
'ANY',
'call',
'create_autospec',
'AsyncMock',
'FILTER_DIR',
'NonCallableMock',
'NonCallableMagicMock',
@ -24,13 +25,13 @@ __all__ = (
__version__ = '1.0'
import asyncio
import io
import inspect
import pprint
import sys
import builtins
from types import ModuleType, MethodType
from types import CodeType, ModuleType, MethodType
from unittest.util import safe_repr
from functools import wraps, partial
@ -43,6 +44,13 @@ FILTER_DIR = True
# Without this, the __class__ properties wouldn't be set correctly
_safe_super = super
def _is_async_obj(obj):
if getattr(obj, '__code__', None):
return asyncio.iscoroutinefunction(obj) or inspect.isawaitable(obj)
else:
return False
def _is_instance_mock(obj):
# can't use isinstance on Mock objects because they override __class__
# The base class for all mocks is NonCallableMock
@ -355,7 +363,20 @@ class NonCallableMock(Base):
# every instance has its own class
# so we can create magic methods on the
# class without stomping on other mocks
new = type(cls.__name__, (cls,), {'__doc__': cls.__doc__})
bases = (cls,)
if not issubclass(cls, AsyncMock):
# Check if spec is an async object or function
sig = inspect.signature(NonCallableMock.__init__)
bound_args = sig.bind_partial(cls, *args, **kw).arguments
spec_arg = [
arg for arg in bound_args.keys()
if arg.startswith('spec')
]
if spec_arg:
# what if spec_set is different than spec?
if _is_async_obj(bound_args[spec_arg[0]]):
bases = (AsyncMockMixin, cls,)
new = type(cls.__name__, bases, {'__doc__': cls.__doc__})
instance = object.__new__(new)
return instance
@ -431,6 +452,11 @@ class NonCallableMock(Base):
_eat_self=False):
_spec_class = None
_spec_signature = None
_spec_asyncs = []
for attr in dir(spec):
if asyncio.iscoroutinefunction(getattr(spec, attr, None)):
_spec_asyncs.append(attr)
if spec is not None and not _is_list(spec):
if isinstance(spec, type):
@ -448,7 +474,7 @@ class NonCallableMock(Base):
__dict__['_spec_set'] = spec_set
__dict__['_spec_signature'] = _spec_signature
__dict__['_mock_methods'] = spec
__dict__['_spec_asyncs'] = _spec_asyncs
def __get_return_value(self):
ret = self._mock_return_value
@ -886,7 +912,15 @@ class NonCallableMock(Base):
For non-callable mocks the callable variant will be used (rather than
any custom subclass)."""
_new_name = kw.get("_new_name")
if _new_name in self.__dict__['_spec_asyncs']:
return AsyncMock(**kw)
_type = type(self)
if issubclass(_type, MagicMock) and _new_name in _async_method_magics:
klass = AsyncMock
if issubclass(_type, AsyncMockMixin):
klass = MagicMock
if not issubclass(_type, CallableMixin):
if issubclass(_type, NonCallableMagicMock):
klass = MagicMock
@ -932,14 +966,12 @@ def _try_iter(obj):
return obj
class CallableMixin(Base):
def __init__(self, spec=None, side_effect=None, return_value=DEFAULT,
wraps=None, name=None, spec_set=None, parent=None,
_spec_state=None, _new_name='', _new_parent=None, **kwargs):
self.__dict__['_mock_return_value'] = return_value
_safe_super(CallableMixin, self).__init__(
spec, wraps, name, spec_set, parent,
_spec_state, _new_name, _new_parent, **kwargs
@ -1081,7 +1113,6 @@ class Mock(CallableMixin, NonCallableMock):
"""
def _dot_lookup(thing, comp, import_path):
try:
return getattr(thing, comp)
@ -1279,8 +1310,10 @@ class _patch(object):
if isinstance(original, type):
# If we're patching out a class and there is a spec
inherit = True
Klass = MagicMock
if spec is None and _is_async_obj(original):
Klass = AsyncMock
else:
Klass = MagicMock
_kwargs = {}
if new_callable is not None:
Klass = new_callable
@ -1292,7 +1325,9 @@ class _patch(object):
not_callable = '__call__' not in this_spec
else:
not_callable = not callable(this_spec)
if not_callable:
if _is_async_obj(this_spec):
Klass = AsyncMock
elif not_callable:
Klass = NonCallableMagicMock
if spec is not None:
@ -1733,7 +1768,7 @@ _non_defaults = {
'__reduce__', '__reduce_ex__', '__getinitargs__', '__getnewargs__',
'__getstate__', '__setstate__', '__getformat__', '__setformat__',
'__repr__', '__dir__', '__subclasses__', '__format__',
'__getnewargs_ex__',
'__getnewargs_ex__', '__aenter__', '__aexit__', '__anext__', '__aiter__',
}
@ -1750,6 +1785,11 @@ _magics = {
' '.join([magic_methods, numerics, inplace, right]).split()
}
# Magic methods used for async `with` statements
_async_method_magics = {"__aenter__", "__aexit__", "__anext__"}
# `__aiter__` is a plain function but used with async calls
_async_magics = _async_method_magics | {"__aiter__"}
_all_magics = _magics | _non_defaults
_unsupported_magics = {
@ -1779,6 +1819,7 @@ _return_values = {
'__float__': 1.0,
'__bool__': True,
'__index__': 1,
'__aexit__': False,
}
@ -1811,10 +1852,19 @@ def _get_iter(self):
return iter(ret_val)
return __iter__
def _get_async_iter(self):
def __aiter__():
ret_val = self.__aiter__._mock_return_value
if ret_val is DEFAULT:
return _AsyncIterator(iter([]))
return _AsyncIterator(iter(ret_val))
return __aiter__
_side_effect_methods = {
'__eq__': _get_eq,
'__ne__': _get_ne,
'__iter__': _get_iter,
'__aiter__': _get_async_iter
}
@ -1879,8 +1929,33 @@ class NonCallableMagicMock(MagicMixin, NonCallableMock):
self._mock_set_magics()
class AsyncMagicMixin:
def __init__(self, *args, **kw):
self._mock_set_async_magics() # make magic work for kwargs in init
_safe_super(AsyncMagicMixin, self).__init__(*args, **kw)
self._mock_set_async_magics() # fix magic broken by upper level init
class MagicMock(MagicMixin, Mock):
def _mock_set_async_magics(self):
these_magics = _async_magics
if getattr(self, "_mock_methods", None) is not None:
these_magics = _async_magics.intersection(self._mock_methods)
remove_magics = _async_magics - these_magics
for entry in remove_magics:
if entry in type(self).__dict__:
# remove unneeded magic methods
delattr(self, entry)
# don't overwrite existing attributes if called a second time
these_magics = these_magics - set(type(self).__dict__)
_type = type(self)
for entry in these_magics:
setattr(_type, entry, MagicProxy(entry, self))
class MagicMock(MagicMixin, AsyncMagicMixin, Mock):
"""
MagicMock is a subclass of Mock with default implementations
of most of the magic methods. You can use MagicMock without having to
@ -1920,6 +1995,218 @@ class MagicProxy(object):
return self.create_mock()
class AsyncMockMixin(Base):
awaited = _delegating_property('awaited')
await_count = _delegating_property('await_count')
await_args = _delegating_property('await_args')
await_args_list = _delegating_property('await_args_list')
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# asyncio.iscoroutinefunction() checks _is_coroutine property to say if an
# object is a coroutine. Without this check it looks to see if it is a
# function/method, which in this case it is not (since it is an
# AsyncMock).
# It is set through __dict__ because when spec_set is True, this
# attribute is likely undefined.
self.__dict__['_is_coroutine'] = asyncio.coroutines._is_coroutine
self.__dict__['_mock_awaited'] = _AwaitEvent(self)
self.__dict__['_mock_await_count'] = 0
self.__dict__['_mock_await_args'] = None
self.__dict__['_mock_await_args_list'] = _CallList()
code_mock = NonCallableMock(spec_set=CodeType)
code_mock.co_flags = inspect.CO_COROUTINE
self.__dict__['__code__'] = code_mock
async def _mock_call(_mock_self, *args, **kwargs):
self = _mock_self
try:
result = super()._mock_call(*args, **kwargs)
except (BaseException, StopIteration) as e:
side_effect = self.side_effect
if side_effect is not None and not callable(side_effect):
raise
return await _raise(e)
_call = self.call_args
async def proxy():
try:
if inspect.isawaitable(result):
return await result
else:
return result
finally:
self.await_count += 1
self.await_args = _call
self.await_args_list.append(_call)
await self.awaited._notify()
return await proxy()
def assert_awaited(_mock_self):
"""
Assert that the mock was awaited at least once.
"""
self = _mock_self
if self.await_count == 0:
msg = f"Expected {self._mock_name or 'mock'} to have been awaited."
raise AssertionError(msg)
def assert_awaited_once(_mock_self):
"""
Assert that the mock was awaited exactly once.
"""
self = _mock_self
if not self.await_count == 1:
msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once."
f" Awaited {self.await_count} times.")
raise AssertionError(msg)
def assert_awaited_with(_mock_self, *args, **kwargs):
"""
Assert that the last await was with the specified arguments.
"""
self = _mock_self
if self.await_args is None:
expected = self._format_mock_call_signature(args, kwargs)
raise AssertionError(f'Expected await: {expected}\nNot awaited')
def _error_message():
msg = self._format_mock_failure_message(args, kwargs)
return msg
expected = self._call_matcher((args, kwargs))
actual = self._call_matcher(self.await_args)
if expected != actual:
cause = expected if isinstance(expected, Exception) else None
raise AssertionError(_error_message()) from cause
def assert_awaited_once_with(_mock_self, *args, **kwargs):
"""
Assert that the mock was awaited exactly once and with the specified
arguments.
"""
self = _mock_self
if not self.await_count == 1:
msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once."
f" Awaited {self.await_count} times.")
raise AssertionError(msg)
return self.assert_awaited_with(*args, **kwargs)
def assert_any_await(_mock_self, *args, **kwargs):
"""
Assert the mock has ever been awaited with the specified arguments.
"""
self = _mock_self
expected = self._call_matcher((args, kwargs))
actual = [self._call_matcher(c) for c in self.await_args_list]
if expected not in actual:
cause = expected if isinstance(expected, Exception) else None
expected_string = self._format_mock_call_signature(args, kwargs)
raise AssertionError(
'%s await not found' % expected_string
) from cause
def assert_has_awaits(_mock_self, calls, any_order=False):
"""
Assert the mock has been awaited with the specified calls.
The :attr:`await_args_list` list is checked for the awaits.
If `any_order` is False (the default) then the awaits must be
sequential. There can be extra calls before or after the
specified awaits.
If `any_order` is True then the awaits can be in any order, but
they must all appear in :attr:`await_args_list`.
"""
self = _mock_self
expected = [self._call_matcher(c) for c in calls]
cause = expected if isinstance(expected, Exception) else None
all_awaits = _CallList(self._call_matcher(c) for c in self.await_args_list)
if not any_order:
if expected not in all_awaits:
raise AssertionError(
f'Awaits not found.\nExpected: {_CallList(calls)}\n',
f'Actual: {self.await_args_list}'
) from cause
return
all_awaits = list(all_awaits)
not_found = []
for kall in expected:
try:
all_awaits.remove(kall)
except ValueError:
not_found.append(kall)
if not_found:
raise AssertionError(
'%r not all found in await list' % (tuple(not_found),)
) from cause
def assert_not_awaited(_mock_self):
"""
Assert that the mock was never awaited.
"""
self = _mock_self
if self.await_count != 0:
msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once."
f" Awaited {self.await_count} times.")
raise AssertionError(msg)
def reset_mock(self, *args, **kwargs):
"""
See :func:`.Mock.reset_mock()`
"""
super().reset_mock(*args, **kwargs)
self.await_count = 0
self.await_args = None
self.await_args_list = _CallList()
class AsyncMock(AsyncMockMixin, AsyncMagicMixin, Mock):
"""
Enhance :class:`Mock` with features allowing to mock
an async function.
The :class:`AsyncMock` object will behave so the object is
recognized as an async function, and the result of a call is an awaitable:
>>> mock = AsyncMock()
>>> asyncio.iscoroutinefunction(mock)
True
>>> inspect.isawaitable(mock())
True
The result of ``mock()`` is an async function which will have the outcome
of ``side_effect`` or ``return_value``:
- if ``side_effect`` is a function, the async function will return the
result of that function,
- if ``side_effect`` is an exception, the async function will raise the
exception,
- if ``side_effect`` is an iterable, the async function will return the
next value of the iterable, however, if the sequence of result is
exhausted, ``StopIteration`` is raised immediately,
- if ``side_effect`` is not defined, the async function will return the
value defined by ``return_value``, hence, by default, the async function
returns a new :class:`AsyncMock` object.
If the outcome of ``side_effect`` or ``return_value`` is an async function,
the mock async function obtained when the mock object is called will be this
async function itself (and not an async function returning an async
function).
The test author can also specify a wrapped object with ``wraps``. In this
case, the :class:`Mock` object behavior is the same as with an
:class:`.Mock` object: the wrapped object may have methods
defined as async function functions.
Based on Martin Richard's asyntest project.
"""
class _ANY(object):
"A helper object that compares equal to everything."
@ -2145,7 +2432,6 @@ class _Call(tuple):
call = _Call(from_kall=False)
def create_autospec(spec, spec_set=False, instance=False, _parent=None,
_name=None, **kwargs):
"""Create a mock object using another object as a spec. Attributes on the
@ -2171,7 +2457,10 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
spec = type(spec)
is_type = isinstance(spec, type)
if getattr(spec, '__code__', None):
is_async_func = asyncio.iscoroutinefunction(spec)
else:
is_async_func = False
_kwargs = {'spec': spec}
if spec_set:
_kwargs = {'spec_set': spec}
@ -2188,6 +2477,11 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
# descriptors don't have a spec
# because we don't know what type they return
_kwargs = {}
elif is_async_func:
if instance:
raise RuntimeError("Instance can not be True when create_autospec "
"is mocking an async function")
Klass = AsyncMock
elif not _callable(spec):
Klass = NonCallableMagicMock
elif is_type and instance and not _instance_callable(spec):
@ -2204,9 +2498,26 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
name=_name, **_kwargs)
if isinstance(spec, FunctionTypes):
wrapped_mock = mock
# should only happen at the top level because we don't
# recurse for functions
mock = _set_signature(mock, spec)
if is_async_func:
mock._is_coroutine = asyncio.coroutines._is_coroutine
mock.await_count = 0
mock.await_args = None
mock.await_args_list = _CallList()
for a in ('assert_awaited',
'assert_awaited_once',
'assert_awaited_with',
'assert_awaited_once_with',
'assert_any_await',
'assert_has_awaits',
'assert_not_awaited'):
def f(*args, **kwargs):
return getattr(wrapped_mock, a)(*args, **kwargs)
setattr(mock, a, f)
else:
_check_signature(spec, mock, is_type, instance)
@ -2250,9 +2561,13 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
skipfirst = _must_skip(spec, entry, is_type)
kwargs['_eat_self'] = skipfirst
new = MagicMock(parent=parent, name=entry, _new_name=entry,
_new_parent=parent,
**kwargs)
if asyncio.iscoroutinefunction(original):
child_klass = AsyncMock
else:
child_klass = MagicMock
new = child_klass(parent=parent, name=entry, _new_name=entry,
_new_parent=parent,
**kwargs)
mock._mock_children[entry] = new
_check_signature(original, new, skipfirst=skipfirst)
@ -2438,3 +2753,60 @@ def seal(mock):
continue
if m._mock_new_parent is mock:
seal(m)
async def _raise(exception):
raise exception
class _AsyncIterator:
"""
Wraps an iterator in an asynchronous iterator.
"""
def __init__(self, iterator):
self.iterator = iterator
code_mock = NonCallableMock(spec_set=CodeType)
code_mock.co_flags = inspect.CO_ITERABLE_COROUTINE
self.__dict__['__code__'] = code_mock
def __aiter__(self):
return self
async def __anext__(self):
try:
return next(self.iterator)
except StopIteration:
pass
raise StopAsyncIteration
class _AwaitEvent:
def __init__(self, mock):
self._mock = mock
self._condition = None
async def _notify(self):
condition = self._get_condition()
try:
await condition.acquire()
condition.notify_all()
finally:
condition.release()
def _get_condition(self):
"""
Creation of condition is delayed, to minimize the chance of using the
wrong loop.
A user may create a mock with _AwaitEvent before selecting the
execution loop. Requiring a user to delay creation is error-prone and
inflexible. Instead, condition is created when user actually starts to
use the mock.
"""
# No synchronization is needed:
# - asyncio is thread unsafe
# - there are no awaits here, method will be executed without
# switching asyncio context.
if self._condition is None:
self._condition = asyncio.Condition()
return self._condition

View File

@ -0,0 +1,549 @@
import asyncio
import inspect
import unittest
from unittest.mock import call, AsyncMock, patch, MagicMock, create_autospec
def tearDownModule():
asyncio.set_event_loop_policy(None)
class AsyncClass:
def __init__(self):
pass
async def async_method(self):
pass
def normal_method(self):
pass
async def async_func():
pass
def normal_func():
pass
class NormalClass(object):
def a(self):
pass
async_foo_name = f'{__name__}.AsyncClass'
normal_foo_name = f'{__name__}.NormalClass'
class AsyncPatchDecoratorTest(unittest.TestCase):
def test_is_coroutine_function_patch(self):
@patch.object(AsyncClass, 'async_method')
def test_async(mock_method):
self.assertTrue(asyncio.iscoroutinefunction(mock_method))
test_async()
def test_is_async_patch(self):
@patch.object(AsyncClass, 'async_method')
def test_async(mock_method):
m = mock_method()
self.assertTrue(inspect.isawaitable(m))
asyncio.run(m)
@patch(f'{async_foo_name}.async_method')
def test_no_parent_attribute(mock_method):
m = mock_method()
self.assertTrue(inspect.isawaitable(m))
asyncio.run(m)
test_async()
test_no_parent_attribute()
def test_is_AsyncMock_patch(self):
@patch.object(AsyncClass, 'async_method')
def test_async(mock_method):
self.assertIsInstance(mock_method, AsyncMock)
test_async()
class AsyncPatchCMTest(unittest.TestCase):
def test_is_async_function_cm(self):
def test_async():
with patch.object(AsyncClass, 'async_method') as mock_method:
self.assertTrue(asyncio.iscoroutinefunction(mock_method))
test_async()
def test_is_async_cm(self):
def test_async():
with patch.object(AsyncClass, 'async_method') as mock_method:
m = mock_method()
self.assertTrue(inspect.isawaitable(m))
asyncio.run(m)
test_async()
def test_is_AsyncMock_cm(self):
def test_async():
with patch.object(AsyncClass, 'async_method') as mock_method:
self.assertIsInstance(mock_method, AsyncMock)
test_async()
class AsyncMockTest(unittest.TestCase):
def test_iscoroutinefunction_default(self):
mock = AsyncMock()
self.assertTrue(asyncio.iscoroutinefunction(mock))
def test_iscoroutinefunction_function(self):
async def foo(): pass
mock = AsyncMock(foo)
self.assertTrue(asyncio.iscoroutinefunction(mock))
self.assertTrue(inspect.iscoroutinefunction(mock))
def test_isawaitable(self):
mock = AsyncMock()
m = mock()
self.assertTrue(inspect.isawaitable(m))
asyncio.run(m)
self.assertIn('assert_awaited', dir(mock))
def test_iscoroutinefunction_normal_function(self):
def foo(): pass
mock = AsyncMock(foo)
self.assertTrue(asyncio.iscoroutinefunction(mock))
self.assertTrue(inspect.iscoroutinefunction(mock))
def test_future_isfuture(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
fut = asyncio.Future()
loop.stop()
loop.close()
mock = AsyncMock(fut)
self.assertIsInstance(mock, asyncio.Future)
class AsyncAutospecTest(unittest.TestCase):
def test_is_AsyncMock_patch(self):
@patch(async_foo_name, autospec=True)
def test_async(mock_method):
self.assertIsInstance(mock_method.async_method, AsyncMock)
self.assertIsInstance(mock_method, MagicMock)
@patch(async_foo_name, autospec=True)
def test_normal_method(mock_method):
self.assertIsInstance(mock_method.normal_method, MagicMock)
test_async()
test_normal_method()
def test_create_autospec_instance(self):
with self.assertRaises(RuntimeError):
create_autospec(async_func, instance=True)
def test_create_autospec(self):
spec = create_autospec(async_func)
self.assertTrue(asyncio.iscoroutinefunction(spec))
class AsyncSpecTest(unittest.TestCase):
def test_spec_as_async_positional_magicmock(self):
mock = MagicMock(async_func)
self.assertIsInstance(mock, MagicMock)
m = mock()
self.assertTrue(inspect.isawaitable(m))
asyncio.run(m)
def test_spec_as_async_kw_magicmock(self):
mock = MagicMock(spec=async_func)
self.assertIsInstance(mock, MagicMock)
m = mock()
self.assertTrue(inspect.isawaitable(m))
asyncio.run(m)
def test_spec_as_async_kw_AsyncMock(self):
mock = AsyncMock(spec=async_func)
self.assertIsInstance(mock, AsyncMock)
m = mock()
self.assertTrue(inspect.isawaitable(m))
asyncio.run(m)
def test_spec_as_async_positional_AsyncMock(self):
mock = AsyncMock(async_func)
self.assertIsInstance(mock, AsyncMock)
m = mock()
self.assertTrue(inspect.isawaitable(m))
asyncio.run(m)
def test_spec_as_normal_kw_AsyncMock(self):
mock = AsyncMock(spec=normal_func)
self.assertIsInstance(mock, AsyncMock)
m = mock()
self.assertTrue(inspect.isawaitable(m))
asyncio.run(m)
def test_spec_as_normal_positional_AsyncMock(self):
mock = AsyncMock(normal_func)
self.assertIsInstance(mock, AsyncMock)
m = mock()
self.assertTrue(inspect.isawaitable(m))
asyncio.run(m)
def test_spec_async_mock(self):
@patch.object(AsyncClass, 'async_method', spec=True)
def test_async(mock_method):
self.assertIsInstance(mock_method, AsyncMock)
test_async()
def test_spec_parent_not_async_attribute_is(self):
@patch(async_foo_name, spec=True)
def test_async(mock_method):
self.assertIsInstance(mock_method, MagicMock)
self.assertIsInstance(mock_method.async_method, AsyncMock)
test_async()
def test_target_async_spec_not(self):
@patch.object(AsyncClass, 'async_method', spec=NormalClass.a)
def test_async_attribute(mock_method):
self.assertIsInstance(mock_method, MagicMock)
self.assertFalse(inspect.iscoroutine(mock_method))
self.assertFalse(inspect.isawaitable(mock_method))
test_async_attribute()
def test_target_not_async_spec_is(self):
@patch.object(NormalClass, 'a', spec=async_func)
def test_attribute_not_async_spec_is(mock_async_func):
self.assertIsInstance(mock_async_func, AsyncMock)
test_attribute_not_async_spec_is()
def test_spec_async_attributes(self):
@patch(normal_foo_name, spec=AsyncClass)
def test_async_attributes_coroutines(MockNormalClass):
self.assertIsInstance(MockNormalClass.async_method, AsyncMock)
self.assertIsInstance(MockNormalClass, MagicMock)
test_async_attributes_coroutines()
class AsyncSpecSetTest(unittest.TestCase):
def test_is_AsyncMock_patch(self):
@patch.object(AsyncClass, 'async_method', spec_set=True)
def test_async(async_method):
self.assertIsInstance(async_method, AsyncMock)
def test_is_async_AsyncMock(self):
mock = AsyncMock(spec_set=AsyncClass.async_method)
self.assertTrue(asyncio.iscoroutinefunction(mock))
self.assertIsInstance(mock, AsyncMock)
def test_is_child_AsyncMock(self):
mock = MagicMock(spec_set=AsyncClass)
self.assertTrue(asyncio.iscoroutinefunction(mock.async_method))
self.assertFalse(asyncio.iscoroutinefunction(mock.normal_method))
self.assertIsInstance(mock.async_method, AsyncMock)
self.assertIsInstance(mock.normal_method, MagicMock)
self.assertIsInstance(mock, MagicMock)
class AsyncArguments(unittest.TestCase):
def test_add_return_value(self):
async def addition(self, var):
return var + 1
mock = AsyncMock(addition, return_value=10)
output = asyncio.run(mock(5))
self.assertEqual(output, 10)
def test_add_side_effect_exception(self):
async def addition(var):
return var + 1
mock = AsyncMock(addition, side_effect=Exception('err'))
with self.assertRaises(Exception):
asyncio.run(mock(5))
def test_add_side_effect_function(self):
async def addition(var):
return var + 1
mock = AsyncMock(side_effect=addition)
result = asyncio.run(mock(5))
self.assertEqual(result, 6)
def test_add_side_effect_iterable(self):
vals = [1, 2, 3]
mock = AsyncMock(side_effect=vals)
for item in vals:
self.assertEqual(item, asyncio.run(mock()))
with self.assertRaises(RuntimeError) as e:
asyncio.run(mock())
self.assertEqual(
e.exception,
RuntimeError('coroutine raised StopIteration')
)
class AsyncContextManagerTest(unittest.TestCase):
class WithAsyncContextManager:
def __init__(self):
self.entered = False
self.exited = False
async def __aenter__(self, *args, **kwargs):
self.entered = True
return self
async def __aexit__(self, *args, **kwargs):
self.exited = True
def test_magic_methods_are_async_mocks(self):
mock = MagicMock(self.WithAsyncContextManager())
self.assertIsInstance(mock.__aenter__, AsyncMock)
self.assertIsInstance(mock.__aexit__, AsyncMock)
def test_mock_supports_async_context_manager(self):
called = False
instance = self.WithAsyncContextManager()
mock_instance = MagicMock(instance)
async def use_context_manager():
nonlocal called
async with mock_instance as result:
called = True
return result
result = asyncio.run(use_context_manager())
self.assertFalse(instance.entered)
self.assertFalse(instance.exited)
self.assertTrue(called)
self.assertTrue(mock_instance.entered)
self.assertTrue(mock_instance.exited)
self.assertTrue(mock_instance.__aenter__.called)
self.assertTrue(mock_instance.__aexit__.called)
self.assertIsNot(mock_instance, result)
self.assertIsInstance(result, AsyncMock)
def test_mock_customize_async_context_manager(self):
instance = self.WithAsyncContextManager()
mock_instance = MagicMock(instance)
expected_result = object()
mock_instance.__aenter__.return_value = expected_result
async def use_context_manager():
async with mock_instance as result:
return result
self.assertIs(asyncio.run(use_context_manager()), expected_result)
def test_mock_customize_async_context_manager_with_coroutine(self):
enter_called = False
exit_called = False
async def enter_coroutine(*args):
nonlocal enter_called
enter_called = True
async def exit_coroutine(*args):
nonlocal exit_called
exit_called = True
instance = self.WithAsyncContextManager()
mock_instance = MagicMock(instance)
mock_instance.__aenter__ = enter_coroutine
mock_instance.__aexit__ = exit_coroutine
async def use_context_manager():
async with mock_instance:
pass
asyncio.run(use_context_manager())
self.assertTrue(enter_called)
self.assertTrue(exit_called)
def test_context_manager_raise_exception_by_default(self):
async def raise_in(context_manager):
async with context_manager:
raise TypeError()
instance = self.WithAsyncContextManager()
mock_instance = MagicMock(instance)
with self.assertRaises(TypeError):
asyncio.run(raise_in(mock_instance))
class AsyncIteratorTest(unittest.TestCase):
class WithAsyncIterator(object):
def __init__(self):
self.items = ["foo", "NormalFoo", "baz"]
def __aiter__(self):
return self
async def __anext__(self):
try:
return self.items.pop()
except IndexError:
pass
raise StopAsyncIteration
def test_mock_aiter_and_anext(self):
instance = self.WithAsyncIterator()
mock_instance = MagicMock(instance)
self.assertEqual(asyncio.iscoroutine(instance.__aiter__),
asyncio.iscoroutine(mock_instance.__aiter__))
self.assertEqual(asyncio.iscoroutine(instance.__anext__),
asyncio.iscoroutine(mock_instance.__anext__))
iterator = instance.__aiter__()
if asyncio.iscoroutine(iterator):
iterator = asyncio.run(iterator)
mock_iterator = mock_instance.__aiter__()
if asyncio.iscoroutine(mock_iterator):
mock_iterator = asyncio.run(mock_iterator)
self.assertEqual(asyncio.iscoroutine(iterator.__aiter__),
asyncio.iscoroutine(mock_iterator.__aiter__))
self.assertEqual(asyncio.iscoroutine(iterator.__anext__),
asyncio.iscoroutine(mock_iterator.__anext__))
def test_mock_async_for(self):
async def iterate(iterator):
accumulator = []
async for item in iterator:
accumulator.append(item)
return accumulator
expected = ["FOO", "BAR", "BAZ"]
with self.subTest("iterate through default value"):
mock_instance = MagicMock(self.WithAsyncIterator())
self.assertEqual([], asyncio.run(iterate(mock_instance)))
with self.subTest("iterate through set return_value"):
mock_instance = MagicMock(self.WithAsyncIterator())
mock_instance.__aiter__.return_value = expected[:]
self.assertEqual(expected, asyncio.run(iterate(mock_instance)))
with self.subTest("iterate through set return_value iterator"):
mock_instance = MagicMock(self.WithAsyncIterator())
mock_instance.__aiter__.return_value = iter(expected[:])
self.assertEqual(expected, asyncio.run(iterate(mock_instance)))
class AsyncMockAssert(unittest.TestCase):
def setUp(self):
self.mock = AsyncMock()
async def _runnable_test(self, *args):
if not args:
await self.mock()
else:
await self.mock(*args)
def test_assert_awaited(self):
with self.assertRaises(AssertionError):
self.mock.assert_awaited()
asyncio.run(self._runnable_test())
self.mock.assert_awaited()
def test_assert_awaited_once(self):
with self.assertRaises(AssertionError):
self.mock.assert_awaited_once()
asyncio.run(self._runnable_test())
self.mock.assert_awaited_once()
asyncio.run(self._runnable_test())
with self.assertRaises(AssertionError):
self.mock.assert_awaited_once()
def test_assert_awaited_with(self):
asyncio.run(self._runnable_test())
with self.assertRaises(AssertionError):
self.mock.assert_awaited_with('foo')
asyncio.run(self._runnable_test('foo'))
self.mock.assert_awaited_with('foo')
asyncio.run(self._runnable_test('SomethingElse'))
with self.assertRaises(AssertionError):
self.mock.assert_awaited_with('foo')
def test_assert_awaited_once_with(self):
with self.assertRaises(AssertionError):
self.mock.assert_awaited_once_with('foo')
asyncio.run(self._runnable_test('foo'))
self.mock.assert_awaited_once_with('foo')
asyncio.run(self._runnable_test('foo'))
with self.assertRaises(AssertionError):
self.mock.assert_awaited_once_with('foo')
def test_assert_any_wait(self):
with self.assertRaises(AssertionError):
self.mock.assert_any_await('NormalFoo')
asyncio.run(self._runnable_test('foo'))
with self.assertRaises(AssertionError):
self.mock.assert_any_await('NormalFoo')
asyncio.run(self._runnable_test('NormalFoo'))
self.mock.assert_any_await('NormalFoo')
asyncio.run(self._runnable_test('SomethingElse'))
self.mock.assert_any_await('NormalFoo')
def test_assert_has_awaits_no_order(self):
calls = [call('NormalFoo'), call('baz')]
with self.assertRaises(AssertionError):
self.mock.assert_has_awaits(calls)
asyncio.run(self._runnable_test('foo'))
with self.assertRaises(AssertionError):
self.mock.assert_has_awaits(calls)
asyncio.run(self._runnable_test('NormalFoo'))
with self.assertRaises(AssertionError):
self.mock.assert_has_awaits(calls)
asyncio.run(self._runnable_test('baz'))
self.mock.assert_has_awaits(calls)
asyncio.run(self._runnable_test('SomethingElse'))
self.mock.assert_has_awaits(calls)
def test_assert_has_awaits_ordered(self):
calls = [call('NormalFoo'), call('baz')]
with self.assertRaises(AssertionError):
self.mock.assert_has_awaits(calls, any_order=True)
asyncio.run(self._runnable_test('baz'))
with self.assertRaises(AssertionError):
self.mock.assert_has_awaits(calls, any_order=True)
asyncio.run(self._runnable_test('foo'))
with self.assertRaises(AssertionError):
self.mock.assert_has_awaits(calls, any_order=True)
asyncio.run(self._runnable_test('NormalFoo'))
self.mock.assert_has_awaits(calls, any_order=True)
asyncio.run(self._runnable_test('qux'))
self.mock.assert_has_awaits(calls, any_order=True)
def test_assert_not_awaited(self):
self.mock.assert_not_awaited()
asyncio.run(self._runnable_test())
with self.assertRaises(AssertionError):
self.mock.assert_not_awaited()

View File

@ -9,7 +9,7 @@ from unittest import mock
from unittest.mock import (
call, DEFAULT, patch, sentinel,
MagicMock, Mock, NonCallableMock,
NonCallableMagicMock, _Call, _CallList,
NonCallableMagicMock, AsyncMock, _Call, _CallList,
create_autospec
)
@ -1618,7 +1618,8 @@ class MockTest(unittest.TestCase):
def test_adding_child_mock(self):
for Klass in NonCallableMock, Mock, MagicMock, NonCallableMagicMock:
for Klass in (NonCallableMock, Mock, MagicMock, NonCallableMagicMock,
AsyncMock):
mock = Klass()
mock.foo = Mock()

View File

@ -0,0 +1,2 @@
Added AsyncMock to support using unittest to mock asyncio coroutines.
Patch by Lisa Roach.