mirror of https://github.com/python/cpython
bpo-36996: Handle async functions when mock.patch is used as a decorator (GH-13562)
Return a coroutine while patching async functions with a decorator. Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com> https://bugs.python.org/issue36996
This commit is contained in:
parent
71dc7c5fbd
commit
436c2b0d67
|
@ -26,6 +26,7 @@ __all__ = (
|
|||
__version__ = '1.0'
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import io
|
||||
import inspect
|
||||
import pprint
|
||||
|
@ -1220,6 +1221,8 @@ class _patch(object):
|
|||
def __call__(self, func):
|
||||
if isinstance(func, type):
|
||||
return self.decorate_class(func)
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return self.decorate_async_callable(func)
|
||||
return self.decorate_callable(func)
|
||||
|
||||
|
||||
|
@ -1237,41 +1240,68 @@ class _patch(object):
|
|||
return klass
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def decoration_helper(self, patched, args, keywargs):
|
||||
extra_args = []
|
||||
entered_patchers = []
|
||||
patching = None
|
||||
|
||||
exc_info = tuple()
|
||||
try:
|
||||
for patching in patched.patchings:
|
||||
arg = patching.__enter__()
|
||||
entered_patchers.append(patching)
|
||||
if patching.attribute_name is not None:
|
||||
keywargs.update(arg)
|
||||
elif patching.new is DEFAULT:
|
||||
extra_args.append(arg)
|
||||
|
||||
args += tuple(extra_args)
|
||||
yield (args, keywargs)
|
||||
except:
|
||||
if (patching not in entered_patchers and
|
||||
_is_started(patching)):
|
||||
# the patcher may have been started, but an exception
|
||||
# raised whilst entering one of its additional_patchers
|
||||
entered_patchers.append(patching)
|
||||
# Pass the exception to __exit__
|
||||
exc_info = sys.exc_info()
|
||||
# re-raise the exception
|
||||
raise
|
||||
finally:
|
||||
for patching in reversed(entered_patchers):
|
||||
patching.__exit__(*exc_info)
|
||||
|
||||
|
||||
def decorate_callable(self, func):
|
||||
# NB. Keep the method in sync with decorate_async_callable()
|
||||
if hasattr(func, 'patchings'):
|
||||
func.patchings.append(self)
|
||||
return func
|
||||
|
||||
@wraps(func)
|
||||
def patched(*args, **keywargs):
|
||||
extra_args = []
|
||||
entered_patchers = []
|
||||
with self.decoration_helper(patched,
|
||||
args,
|
||||
keywargs) as (newargs, newkeywargs):
|
||||
return func(*newargs, **newkeywargs)
|
||||
|
||||
exc_info = tuple()
|
||||
try:
|
||||
for patching in patched.patchings:
|
||||
arg = patching.__enter__()
|
||||
entered_patchers.append(patching)
|
||||
if patching.attribute_name is not None:
|
||||
keywargs.update(arg)
|
||||
elif patching.new is DEFAULT:
|
||||
extra_args.append(arg)
|
||||
patched.patchings = [self]
|
||||
return patched
|
||||
|
||||
args += tuple(extra_args)
|
||||
return func(*args, **keywargs)
|
||||
except:
|
||||
if (patching not in entered_patchers and
|
||||
_is_started(patching)):
|
||||
# the patcher may have been started, but an exception
|
||||
# raised whilst entering one of its additional_patchers
|
||||
entered_patchers.append(patching)
|
||||
# Pass the exception to __exit__
|
||||
exc_info = sys.exc_info()
|
||||
# re-raise the exception
|
||||
raise
|
||||
finally:
|
||||
for patching in reversed(entered_patchers):
|
||||
patching.__exit__(*exc_info)
|
||||
|
||||
def decorate_async_callable(self, func):
|
||||
# NB. Keep the method in sync with decorate_callable()
|
||||
if hasattr(func, 'patchings'):
|
||||
func.patchings.append(self)
|
||||
return func
|
||||
|
||||
@wraps(func)
|
||||
async def patched(*args, **keywargs):
|
||||
with self.decoration_helper(patched,
|
||||
args,
|
||||
keywargs) as (newargs, newkeywargs):
|
||||
return await func(*newargs, **newkeywargs)
|
||||
|
||||
patched.patchings = [self]
|
||||
return patched
|
||||
|
|
|
@ -66,6 +66,14 @@ class AsyncPatchDecoratorTest(unittest.TestCase):
|
|||
|
||||
test_async()
|
||||
|
||||
def test_async_def_patch(self):
|
||||
@patch(f"{__name__}.async_func", AsyncMock())
|
||||
async def test_async():
|
||||
self.assertIsInstance(async_func, AsyncMock)
|
||||
|
||||
asyncio.run(test_async())
|
||||
self.assertTrue(inspect.iscoroutinefunction(async_func))
|
||||
|
||||
|
||||
class AsyncPatchCMTest(unittest.TestCase):
|
||||
def test_is_async_function_cm(self):
|
||||
|
@ -91,6 +99,14 @@ class AsyncPatchCMTest(unittest.TestCase):
|
|||
|
||||
test_async()
|
||||
|
||||
def test_async_def_cm(self):
|
||||
async def test_async():
|
||||
with patch(f"{__name__}.async_func", AsyncMock()):
|
||||
self.assertIsInstance(async_func, AsyncMock)
|
||||
self.assertTrue(inspect.iscoroutinefunction(async_func))
|
||||
|
||||
asyncio.run(test_async())
|
||||
|
||||
|
||||
class AsyncMockTest(unittest.TestCase):
|
||||
def test_iscoroutinefunction_default(self):
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Handle :func:`unittest.mock.patch` used as a decorator on async functions.
|
Loading…
Reference in New Issue