bpo-37555: Ensure all assert methods using _call_matcher are actually passing calls

This commit is contained in:
Elizabeth Uselton 2019-08-05 00:51:24 -07:00
parent f47699de12
commit 38650c98c6
1 changed files with 25 additions and 9 deletions

View File

@ -864,9 +864,9 @@ class NonCallableMock(Base):
def _error_message():
msg = self._format_mock_failure_message(args, kwargs)
return msg
expected = self._call_matcher((args, kwargs))
expected = self._call_matcher(_Call((args, kwargs)))
actual = self._call_matcher(self.call_args)
if expected != actual:
if actual != expected:
cause = expected if isinstance(expected, Exception) else None
raise AssertionError(_error_message()) from cause
@ -926,10 +926,10 @@ class NonCallableMock(Base):
The assert passes if the mock has *ever* been called, unlike
`assert_called_with` and `assert_called_once_with` that only pass if
the call is the most recent one."""
expected = self._call_matcher((args, kwargs))
expected = self._call_matcher(_Call((args, kwargs), two=True))
cause = expected if isinstance(expected, Exception) else None
actual = [self._call_matcher(c) for c in self.call_args_list]
if expected not in actual:
cause = expected if isinstance(expected, Exception) else None
if cause or expected not in _AnyComparer(actual):
expected_string = self._format_mock_call_signature(args, kwargs)
raise AssertionError(
'%s call not found' % expected_string
@ -982,6 +982,22 @@ class NonCallableMock(Base):
return f"\n{prefix}: {safe_repr(self.mock_calls)}."
class _AnyComparer(list):
"""A list which checks if it contains a call which may have an
argument of ANY, flipping the components of item and self from
their traditional locations so that ANY is guaranteed to be on
the left."""
def __contains__(self, item):
for _call in self:
if len(item) != len(_call):
continue
if all([
expected == actual
for expected, actual in zip(item, _call)
]):
return True
return False
def _try_iter(obj):
if obj is None:
@ -2133,9 +2149,9 @@ class AsyncMockMixin(Base):
msg = self._format_mock_failure_message(args, kwargs, action='await')
return msg
expected = self._call_matcher((args, kwargs))
expected = self._call_matcher(_Call((args, kwargs), two=True))
actual = self._call_matcher(self.await_args)
if expected != actual:
if actual != expected:
cause = expected if isinstance(expected, Exception) else None
raise AssertionError(_error_message()) from cause
@ -2154,9 +2170,9 @@ class AsyncMockMixin(Base):
"""
Assert the mock has ever been awaited with the specified arguments.
"""
expected = self._call_matcher((args, kwargs))
expected = self._call_matcher(_Call((args, kwargs), two=True))
actual = [self._call_matcher(c) for c in self.await_args_list]
if expected not in actual:
if expected not in _AnyComparer(actual):
cause = expected if isinstance(expected, Exception) else None
expected_string = self._format_mock_call_signature(args, kwargs)
raise AssertionError(