diff --git a/Lib/unittest/test/testmock/testasync.py b/Lib/unittest/test/testmock/testasync.py index 149fd4deff1..73d31a29668 100644 --- a/Lib/unittest/test/testmock/testasync.py +++ b/Lib/unittest/test/testmock/testasync.py @@ -72,9 +72,17 @@ class AsyncPatchDecoratorTest(unittest.TestCase): test_async() def test_async_def_patch(self): - @patch(f"{__name__}.async_func", AsyncMock()) - async def test_async(): + @patch(f"{__name__}.async_func", return_value=1) + @patch(f"{__name__}.async_func_args", return_value=2) + async def test_async(func_args_mock, func_mock): + self.assertEqual(func_args_mock._mock_name, "async_func_args") + self.assertEqual(func_mock._mock_name, "async_func") + self.assertIsInstance(async_func, AsyncMock) + self.assertIsInstance(async_func_args, AsyncMock) + + self.assertEqual(await async_func(), 1) + self.assertEqual(await async_func_args(1, 2, c=3), 2) asyncio.run(test_async()) self.assertTrue(inspect.iscoroutinefunction(async_func)) @@ -375,22 +383,40 @@ class AsyncArguments(unittest.IsolatedAsyncioTestCase): with self.assertRaises(Exception): await mock(5) - async def test_add_side_effect_function(self): + async def test_add_side_effect_coroutine(self): async def addition(var): return var + 1 mock = AsyncMock(side_effect=addition) result = await mock(5) self.assertEqual(result, 6) + async def test_add_side_effect_normal_function(self): + def addition(var): + return var + 1 + mock = AsyncMock(side_effect=addition) + result = await mock(5) + self.assertEqual(result, 6) + async def test_add_side_effect_iterable(self): vals = [1, 2, 3] mock = AsyncMock(side_effect=vals) for item in vals: - self.assertEqual(item, await mock()) + self.assertEqual(await mock(), item) with self.assertRaises(StopAsyncIteration) as e: await mock() + async def test_add_side_effect_exception_iterable(self): + class SampleException(Exception): + pass + + vals = [1, SampleException("foo")] + mock = AsyncMock(side_effect=vals) + self.assertEqual(await mock(), 1) + + with self.assertRaises(SampleException) as e: + await mock() + async def test_return_value_AsyncMock(self): value = AsyncMock(return_value=10) mock = AsyncMock(return_value=value) @@ -437,6 +463,21 @@ class AsyncArguments(unittest.IsolatedAsyncioTestCase): mock.assert_awaited() self.assertTrue(ran) + async def test_wraps_normal_function(self): + value = 1 + + ran = False + def inner(): + nonlocal ran + ran = True + return value + + mock = AsyncMock(wraps=inner) + result = await mock() + self.assertEqual(result, value) + mock.assert_awaited() + self.assertTrue(ran) + class AsyncMagicMethods(unittest.TestCase): def test_async_magic_methods_return_async_mocks(self): m_mock = MagicMock() @@ -860,6 +901,10 @@ class AsyncMockAssert(unittest.TestCase): self.mock.assert_awaited_once() def test_assert_awaited_with(self): + msg = 'Not awaited' + with self.assertRaisesRegex(AssertionError, msg): + self.mock.assert_awaited_with('foo') + asyncio.run(self._runnable_test()) msg = 'expected await not found' with self.assertRaisesRegex(AssertionError, msg):