@@ -72,9 +72,17 @@ def test_async(mock_method):
72
72
test_async ()
73
73
74
74
def test_async_def_patch (self ):
75
- @patch (f"{ __name__ } .async_func" , AsyncMock ())
76
- async def test_async ():
75
+ @patch (f"{ __name__ } .async_func" , return_value = 1 )
76
+ @patch (f"{ __name__ } .async_func_args" , return_value = 2 )
77
+ async def test_async (func_args_mock , func_mock ):
78
+ self .assertEqual (func_args_mock ._mock_name , "async_func_args" )
79
+ self .assertEqual (func_mock ._mock_name , "async_func" )
80
+
77
81
self .assertIsInstance (async_func , AsyncMock )
82
+ self .assertIsInstance (async_func_args , AsyncMock )
83
+
84
+ self .assertEqual (await async_func (), 1 )
85
+ self .assertEqual (await async_func_args (1 , 2 , c = 3 ), 2 )
78
86
79
87
asyncio .run (test_async ())
80
88
self .assertTrue (inspect .iscoroutinefunction (async_func ))
@@ -370,22 +378,40 @@ async def addition(var):
370
378
with self .assertRaises (Exception ):
371
379
await mock (5 )
372
380
373
- async def test_add_side_effect_function (self ):
381
+ async def test_add_side_effect_coroutine (self ):
374
382
async def addition (var ):
375
383
return var + 1
376
384
mock = AsyncMock (side_effect = addition )
377
385
result = await mock (5 )
378
386
self .assertEqual (result , 6 )
379
387
388
+ async def test_add_side_effect_normal_function (self ):
389
+ def addition (var ):
390
+ return var + 1
391
+ mock = AsyncMock (side_effect = addition )
392
+ result = await mock (5 )
393
+ self .assertEqual (result , 6 )
394
+
380
395
async def test_add_side_effect_iterable (self ):
381
396
vals = [1 , 2 , 3 ]
382
397
mock = AsyncMock (side_effect = vals )
383
398
for item in vals :
384
- self .assertEqual (item , await mock ())
399
+ self .assertEqual (await mock (), item )
385
400
386
401
with self .assertRaises (StopAsyncIteration ) as e :
387
402
await mock ()
388
403
404
+ async def test_add_side_effect_exception_iterable (self ):
405
+ class SampleException (Exception ):
406
+ pass
407
+
408
+ vals = [1 , SampleException ("foo" )]
409
+ mock = AsyncMock (side_effect = vals )
410
+ self .assertEqual (await mock (), 1 )
411
+
412
+ with self .assertRaises (SampleException ) as e :
413
+ await mock ()
414
+
389
415
async def test_return_value_AsyncMock (self ):
390
416
value = AsyncMock (return_value = 10 )
391
417
mock = AsyncMock (return_value = value )
@@ -432,6 +458,21 @@ async def inner():
432
458
mock .assert_awaited ()
433
459
self .assertTrue (ran )
434
460
461
+ async def test_wraps_normal_function (self ):
462
+ value = 1
463
+
464
+ ran = False
465
+ def inner ():
466
+ nonlocal ran
467
+ ran = True
468
+ return value
469
+
470
+ mock = AsyncMock (wraps = inner )
471
+ result = await mock ()
472
+ self .assertEqual (result , value )
473
+ mock .assert_awaited ()
474
+ self .assertTrue (ran )
475
+
435
476
class AsyncMagicMethods (unittest .TestCase ):
436
477
def test_async_magic_methods_return_async_mocks (self ):
437
478
m_mock = MagicMock ()
@@ -854,6 +895,10 @@ def test_assert_awaited_once(self):
854
895
self .mock .assert_awaited_once ()
855
896
856
897
def test_assert_awaited_with (self ):
898
+ msg = 'Not awaited'
899
+ with self .assertRaisesRegex (AssertionError , msg ):
900
+ self .mock .assert_awaited_with ('foo' )
901
+
857
902
asyncio .run (self ._runnable_test ())
858
903
msg = 'expected await not found'
859
904
with self .assertRaisesRegex (AssertionError , msg ):
0 commit comments