From 436c2b0d67da68465e709a96daac7340af3a5238 Mon Sep 17 00:00:00 2001 From: Xtreak Date: Tue, 28 May 2019 12:37:39 +0530 Subject: [PATCH] bpo-36996: Handle async functions when mock.patch is used as a decorator (GH-13562) Return a coroutine while patching async functions with a decorator. Co-authored-by: Andrew Svetlov https://bugs.python.org/issue36996 --- Lib/unittest/mock.py | 82 +++++++++++++------ Lib/unittest/test/testmock/testasync.py | 16 ++++ .../2019-05-22-22-55-18.bpo-36996.XQx08d.rst | 1 + 3 files changed, 73 insertions(+), 26 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2019-05-22-22-55-18.bpo-36996.XQx08d.rst diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py index b91afd88dd1..fac4535747c 100644 --- a/Lib/unittest/mock.py +++ b/Lib/unittest/mock.py @@ -26,6 +26,7 @@ __all__ = ( __version__ = '1.0' import asyncio +import contextlib import io import inspect import pprint @@ -1220,6 +1221,8 @@ class _patch(object): def __call__(self, func): if isinstance(func, type): return self.decorate_class(func) + if inspect.iscoroutinefunction(func): + return self.decorate_async_callable(func) return self.decorate_callable(func) @@ -1237,41 +1240,68 @@ class _patch(object): return klass + @contextlib.contextmanager + def decoration_helper(self, patched, args, keywargs): + extra_args = [] + entered_patchers = [] + patching = None + + exc_info = tuple() + try: + for patching in patched.patchings: + arg = patching.__enter__() + entered_patchers.append(patching) + if patching.attribute_name is not None: + keywargs.update(arg) + elif patching.new is DEFAULT: + extra_args.append(arg) + + args += tuple(extra_args) + yield (args, keywargs) + except: + if (patching not in entered_patchers and + _is_started(patching)): + # the patcher may have been started, but an exception + # raised whilst entering one of its additional_patchers + entered_patchers.append(patching) + # Pass the exception to __exit__ + exc_info = sys.exc_info() + # re-raise the exception + raise + finally: + for patching in reversed(entered_patchers): + patching.__exit__(*exc_info) + + def decorate_callable(self, func): + # NB. Keep the method in sync with decorate_async_callable() if hasattr(func, 'patchings'): func.patchings.append(self) return func @wraps(func) def patched(*args, **keywargs): - extra_args = [] - entered_patchers = [] + with self.decoration_helper(patched, + args, + keywargs) as (newargs, newkeywargs): + return func(*newargs, **newkeywargs) - exc_info = tuple() - try: - for patching in patched.patchings: - arg = patching.__enter__() - entered_patchers.append(patching) - if patching.attribute_name is not None: - keywargs.update(arg) - elif patching.new is DEFAULT: - extra_args.append(arg) + patched.patchings = [self] + return patched - args += tuple(extra_args) - return func(*args, **keywargs) - except: - if (patching not in entered_patchers and - _is_started(patching)): - # the patcher may have been started, but an exception - # raised whilst entering one of its additional_patchers - entered_patchers.append(patching) - # Pass the exception to __exit__ - exc_info = sys.exc_info() - # re-raise the exception - raise - finally: - for patching in reversed(entered_patchers): - patching.__exit__(*exc_info) + + def decorate_async_callable(self, func): + # NB. Keep the method in sync with decorate_callable() + if hasattr(func, 'patchings'): + func.patchings.append(self) + return func + + @wraps(func) + async def patched(*args, **keywargs): + with self.decoration_helper(patched, + args, + keywargs) as (newargs, newkeywargs): + return await func(*newargs, **newkeywargs) patched.patchings = [self] return patched diff --git a/Lib/unittest/test/testmock/testasync.py b/Lib/unittest/test/testmock/testasync.py index 0519d59696f..ccea4fe242d 100644 --- a/Lib/unittest/test/testmock/testasync.py +++ b/Lib/unittest/test/testmock/testasync.py @@ -66,6 +66,14 @@ class AsyncPatchDecoratorTest(unittest.TestCase): test_async() + def test_async_def_patch(self): + @patch(f"{__name__}.async_func", AsyncMock()) + async def test_async(): + self.assertIsInstance(async_func, AsyncMock) + + asyncio.run(test_async()) + self.assertTrue(inspect.iscoroutinefunction(async_func)) + class AsyncPatchCMTest(unittest.TestCase): def test_is_async_function_cm(self): @@ -91,6 +99,14 @@ class AsyncPatchCMTest(unittest.TestCase): test_async() + def test_async_def_cm(self): + async def test_async(): + with patch(f"{__name__}.async_func", AsyncMock()): + self.assertIsInstance(async_func, AsyncMock) + self.assertTrue(inspect.iscoroutinefunction(async_func)) + + asyncio.run(test_async()) + class AsyncMockTest(unittest.TestCase): def test_iscoroutinefunction_default(self): diff --git a/Misc/NEWS.d/next/Library/2019-05-22-22-55-18.bpo-36996.XQx08d.rst b/Misc/NEWS.d/next/Library/2019-05-22-22-55-18.bpo-36996.XQx08d.rst new file mode 100644 index 00000000000..69d18d9713b --- /dev/null +++ b/Misc/NEWS.d/next/Library/2019-05-22-22-55-18.bpo-36996.XQx08d.rst @@ -0,0 +1 @@ +Handle :func:`unittest.mock.patch` used as a decorator on async functions.