Skip to content

Commit 0ea7309

Browse files
Improve test coverage for AsyncMock. (GH-17906)
* Add test for nested async decorator patch. * Add test for side_effect and wraps with a function. * Add test for side_effect with an exception in the iterable. (cherry picked from commit 54f743e) Co-authored-by: Karthikeyan Singaravelan <[email protected]>
1 parent a46728a commit 0ea7309

File tree

1 file changed

+49
-4
lines changed

1 file changed

+49
-4
lines changed

Lib/unittest/test/testmock/testasync.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,17 @@ def test_async(mock_method):
7272
test_async()
7373

7474
def test_async_def_patch(self):
75-
@patch(f"{__name__}.async_func", AsyncMock())
76-
async def test_async():
75+
@patch(f"{__name__}.async_func", return_value=1)
76+
@patch(f"{__name__}.async_func_args", return_value=2)
77+
async def test_async(func_args_mock, func_mock):
78+
self.assertEqual(func_args_mock._mock_name, "async_func_args")
79+
self.assertEqual(func_mock._mock_name, "async_func")
80+
7781
self.assertIsInstance(async_func, AsyncMock)
82+
self.assertIsInstance(async_func_args, AsyncMock)
83+
84+
self.assertEqual(await async_func(), 1)
85+
self.assertEqual(await async_func_args(1, 2, c=3), 2)
7886

7987
asyncio.run(test_async())
8088
self.assertTrue(inspect.iscoroutinefunction(async_func))
@@ -370,22 +378,40 @@ async def addition(var):
370378
with self.assertRaises(Exception):
371379
await mock(5)
372380

373-
async def test_add_side_effect_function(self):
381+
async def test_add_side_effect_coroutine(self):
374382
async def addition(var):
375383
return var + 1
376384
mock = AsyncMock(side_effect=addition)
377385
result = await mock(5)
378386
self.assertEqual(result, 6)
379387

388+
async def test_add_side_effect_normal_function(self):
389+
def addition(var):
390+
return var + 1
391+
mock = AsyncMock(side_effect=addition)
392+
result = await mock(5)
393+
self.assertEqual(result, 6)
394+
380395
async def test_add_side_effect_iterable(self):
381396
vals = [1, 2, 3]
382397
mock = AsyncMock(side_effect=vals)
383398
for item in vals:
384-
self.assertEqual(item, await mock())
399+
self.assertEqual(await mock(), item)
385400

386401
with self.assertRaises(StopAsyncIteration) as e:
387402
await mock()
388403

404+
async def test_add_side_effect_exception_iterable(self):
405+
class SampleException(Exception):
406+
pass
407+
408+
vals = [1, SampleException("foo")]
409+
mock = AsyncMock(side_effect=vals)
410+
self.assertEqual(await mock(), 1)
411+
412+
with self.assertRaises(SampleException) as e:
413+
await mock()
414+
389415
async def test_return_value_AsyncMock(self):
390416
value = AsyncMock(return_value=10)
391417
mock = AsyncMock(return_value=value)
@@ -432,6 +458,21 @@ async def inner():
432458
mock.assert_awaited()
433459
self.assertTrue(ran)
434460

461+
async def test_wraps_normal_function(self):
462+
value = 1
463+
464+
ran = False
465+
def inner():
466+
nonlocal ran
467+
ran = True
468+
return value
469+
470+
mock = AsyncMock(wraps=inner)
471+
result = await mock()
472+
self.assertEqual(result, value)
473+
mock.assert_awaited()
474+
self.assertTrue(ran)
475+
435476
class AsyncMagicMethods(unittest.TestCase):
436477
def test_async_magic_methods_return_async_mocks(self):
437478
m_mock = MagicMock()
@@ -854,6 +895,10 @@ def test_assert_awaited_once(self):
854895
self.mock.assert_awaited_once()
855896

856897
def test_assert_awaited_with(self):
898+
msg = 'Not awaited'
899+
with self.assertRaisesRegex(AssertionError, msg):
900+
self.mock.assert_awaited_with('foo')
901+
857902
asyncio.run(self._runnable_test())
858903
msg = 'expected await not found'
859904
with self.assertRaisesRegex(AssertionError, msg):

0 commit comments

Comments
 (0)