3
3
import unittest
4
4
5
5
from unittest .mock import (ANY , call , AsyncMock , patch , MagicMock ,
6
- create_autospec , _AwaitEvent )
6
+ create_autospec , _AwaitEvent , sentinel , _CallList )
7
7
8
8
9
9
def tearDownModule ():
@@ -595,11 +595,173 @@ class AsyncMockAssert(unittest.TestCase):
595
595
def setUp (self ):
596
596
self .mock = AsyncMock ()
597
597
598
- async def _runnable_test (self , * args ):
599
- if not args :
600
- await self .mock ()
601
- else :
602
- await self .mock (* args )
598
+ async def _runnable_test (self , * args , ** kwargs ):
599
+ await self .mock (* args , ** kwargs )
600
+
601
+ async def _await_coroutine (self , coroutine ):
602
+ return await coroutine
603
+
604
+ def test_assert_called_but_not_awaited (self ):
605
+ mock = AsyncMock (AsyncClass )
606
+ with self .assertWarns (RuntimeWarning ):
607
+ # Will raise a warning because never awaited
608
+ mock .async_method ()
609
+ self .assertTrue (asyncio .iscoroutinefunction (mock .async_method ))
610
+ mock .async_method .assert_called ()
611
+ mock .async_method .assert_called_once ()
612
+ mock .async_method .assert_called_once_with ()
613
+ with self .assertRaises (AssertionError ):
614
+ mock .assert_awaited ()
615
+ with self .assertRaises (AssertionError ):
616
+ mock .async_method .assert_awaited ()
617
+
618
+ def test_assert_called_then_awaited (self ):
619
+ mock = AsyncMock (AsyncClass )
620
+ mock_coroutine = mock .async_method ()
621
+ mock .async_method .assert_called ()
622
+ mock .async_method .assert_called_once ()
623
+ mock .async_method .assert_called_once_with ()
624
+ with self .assertRaises (AssertionError ):
625
+ mock .async_method .assert_awaited ()
626
+
627
+ asyncio .run (self ._await_coroutine (mock_coroutine ))
628
+ # Assert we haven't re-called the function
629
+ mock .async_method .assert_called_once ()
630
+ mock .async_method .assert_awaited ()
631
+ mock .async_method .assert_awaited_once ()
632
+ mock .async_method .assert_awaited_once_with ()
633
+
634
+ def test_assert_called_and_awaited_at_same_time (self ):
635
+ with self .assertRaises (AssertionError ):
636
+ self .mock .assert_awaited ()
637
+
638
+ with self .assertRaises (AssertionError ):
639
+ self .mock .assert_called ()
640
+
641
+ asyncio .run (self ._runnable_test ())
642
+ self .mock .assert_called_once ()
643
+ self .mock .assert_awaited_once ()
644
+
645
+ def test_assert_called_twice_and_awaited_once (self ):
646
+ mock = AsyncMock (AsyncClass )
647
+ coroutine = mock .async_method ()
648
+ with self .assertWarns (RuntimeWarning ):
649
+ # The first call will be awaited so no warning there
650
+ # But this call will never get awaited, so it will warn here
651
+ mock .async_method ()
652
+ with self .assertRaises (AssertionError ):
653
+ mock .async_method .assert_awaited ()
654
+ mock .async_method .assert_called ()
655
+ asyncio .run (self ._await_coroutine (coroutine ))
656
+ mock .async_method .assert_awaited ()
657
+ mock .async_method .assert_awaited_once ()
658
+
659
+ def test_assert_called_once_and_awaited_twice (self ):
660
+ mock = AsyncMock (AsyncClass )
661
+ coroutine = mock .async_method ()
662
+ mock .async_method .assert_called_once ()
663
+ asyncio .run (self ._await_coroutine (coroutine ))
664
+ with self .assertRaises (RuntimeError ):
665
+ # Cannot reuse already awaited coroutine
666
+ asyncio .run (self ._await_coroutine (coroutine ))
667
+ mock .async_method .assert_awaited ()
668
+
669
+ def test_assert_awaited_but_not_called (self ):
670
+ with self .assertRaises (AssertionError ):
671
+ self .mock .assert_awaited ()
672
+ with self .assertRaises (AssertionError ):
673
+ self .mock .assert_called ()
674
+ with self .assertRaises (TypeError ):
675
+ # You cannot await an AsyncMock, it must be a coroutine
676
+ asyncio .run (self ._await_coroutine (self .mock ))
677
+
678
+ with self .assertRaises (AssertionError ):
679
+ self .mock .assert_awaited ()
680
+ with self .assertRaises (AssertionError ):
681
+ self .mock .assert_called ()
682
+
683
+ def test_assert_has_calls_not_awaits (self ):
684
+ kalls = [call ('foo' )]
685
+ with self .assertWarns (RuntimeWarning ):
686
+ # Will raise a warning because never awaited
687
+ self .mock ('foo' )
688
+ self .mock .assert_has_calls (kalls )
689
+ with self .assertRaises (AssertionError ):
690
+ self .mock .assert_has_awaits (kalls )
691
+
692
+ def test_assert_has_mock_calls_on_async_mock_no_spec (self ):
693
+ with self .assertWarns (RuntimeWarning ):
694
+ # Will raise a warning because never awaited
695
+ self .mock ()
696
+ kalls_empty = [('' , (), {})]
697
+ self .assertEqual (self .mock .mock_calls , kalls_empty )
698
+
699
+ with self .assertWarns (RuntimeWarning ):
700
+ # Will raise a warning because never awaited
701
+ self .mock ('foo' )
702
+ self .mock ('baz' )
703
+ mock_kalls = ([call (), call ('foo' ), call ('baz' )])
704
+ self .assertEqual (self .mock .mock_calls , mock_kalls )
705
+
706
+ def test_assert_has_mock_calls_on_async_mock_with_spec (self ):
707
+ a_class_mock = AsyncMock (AsyncClass )
708
+ with self .assertWarns (RuntimeWarning ):
709
+ # Will raise a warning because never awaited
710
+ a_class_mock .async_method ()
711
+ kalls_empty = [('' , (), {})]
712
+ self .assertEqual (a_class_mock .async_method .mock_calls , kalls_empty )
713
+ self .assertEqual (a_class_mock .mock_calls , [call .async_method ()])
714
+
715
+ with self .assertWarns (RuntimeWarning ):
716
+ # Will raise a warning because never awaited
717
+ a_class_mock .async_method (1 , 2 , 3 , a = 4 , b = 5 )
718
+ method_kalls = [call (), call (1 , 2 , 3 , a = 4 , b = 5 )]
719
+ mock_kalls = [call .async_method (), call .async_method (1 , 2 , 3 , a = 4 , b = 5 )]
720
+ self .assertEqual (a_class_mock .async_method .mock_calls , method_kalls )
721
+ self .assertEqual (a_class_mock .mock_calls , mock_kalls )
722
+
723
+ def test_async_method_calls_recorded (self ):
724
+ with self .assertWarns (RuntimeWarning ):
725
+ # Will raise warnings because never awaited
726
+ self .mock .something (3 , fish = None )
727
+ self .mock .something_else .something (6 , cake = sentinel .Cake )
728
+
729
+ self .assertEqual (self .mock .method_calls , [
730
+ ("something" , (3 ,), {'fish' : None }),
731
+ ("something_else.something" , (6 ,), {'cake' : sentinel .Cake })
732
+ ],
733
+ "method calls not recorded correctly" )
734
+ self .assertEqual (self .mock .something_else .method_calls ,
735
+ [("something" , (6 ,), {'cake' : sentinel .Cake })],
736
+ "method calls not recorded correctly" )
737
+
738
+ def test_async_arg_lists (self ):
739
+ def assert_attrs (mock ):
740
+ names = ('call_args_list' , 'method_calls' , 'mock_calls' )
741
+ for name in names :
742
+ attr = getattr (mock , name )
743
+ self .assertIsInstance (attr , _CallList )
744
+ self .assertIsInstance (attr , list )
745
+ self .assertEqual (attr , [])
746
+
747
+ assert_attrs (self .mock )
748
+ with self .assertWarns (RuntimeWarning ):
749
+ # Will raise warnings because never awaited
750
+ self .mock ()
751
+ self .mock (1 , 2 )
752
+ self .mock (a = 3 )
753
+
754
+ self .mock .reset_mock ()
755
+ assert_attrs (self .mock )
756
+
757
+ a_mock = AsyncMock (AsyncClass )
758
+ with self .assertWarns (RuntimeWarning ):
759
+ # Will raise warnings because never awaited
760
+ a_mock .async_method ()
761
+ a_mock .async_method (1 , a = 3 )
762
+
763
+ a_mock .reset_mock ()
764
+ assert_attrs (a_mock )
603
765
604
766
def test_assert_awaited (self ):
605
767
with self .assertRaises (AssertionError ):
@@ -645,20 +807,20 @@ def test_assert_awaited_once_with(self):
645
807
646
808
def test_assert_any_wait (self ):
647
809
with self .assertRaises (AssertionError ):
648
- self .mock .assert_any_await ('NormalFoo ' )
810
+ self .mock .assert_any_await ('foo ' )
649
811
650
- asyncio .run (self ._runnable_test ('foo ' ))
812
+ asyncio .run (self ._runnable_test ('baz ' ))
651
813
with self .assertRaises (AssertionError ):
652
- self .mock .assert_any_await ('NormalFoo ' )
814
+ self .mock .assert_any_await ('foo ' )
653
815
654
- asyncio .run (self ._runnable_test ('NormalFoo ' ))
655
- self .mock .assert_any_await ('NormalFoo ' )
816
+ asyncio .run (self ._runnable_test ('foo ' ))
817
+ self .mock .assert_any_await ('foo ' )
656
818
657
819
asyncio .run (self ._runnable_test ('SomethingElse' ))
658
- self .mock .assert_any_await ('NormalFoo ' )
820
+ self .mock .assert_any_await ('foo ' )
659
821
660
822
def test_assert_has_awaits_no_order (self ):
661
- calls = [call ('NormalFoo ' ), call ('baz' )]
823
+ calls = [call ('foo ' ), call ('baz' )]
662
824
663
825
with self .assertRaises (AssertionError ) as cm :
664
826
self .mock .assert_has_awaits (calls )
@@ -668,7 +830,7 @@ def test_assert_has_awaits_no_order(self):
668
830
with self .assertRaises (AssertionError ):
669
831
self .mock .assert_has_awaits (calls )
670
832
671
- asyncio .run (self ._runnable_test ('NormalFoo ' ))
833
+ asyncio .run (self ._runnable_test ('foo ' ))
672
834
with self .assertRaises (AssertionError ):
673
835
self .mock .assert_has_awaits (calls )
674
836
@@ -703,19 +865,19 @@ async def _custom_mock_runnable_test(*args):
703
865
mock_with_spec .assert_any_await (ANY , 1 )
704
866
705
867
def test_assert_has_awaits_ordered (self ):
706
- calls = [call ('NormalFoo ' ), call ('baz' )]
868
+ calls = [call ('foo ' ), call ('baz' )]
707
869
with self .assertRaises (AssertionError ):
708
870
self .mock .assert_has_awaits (calls , any_order = True )
709
871
710
872
asyncio .run (self ._runnable_test ('baz' ))
711
873
with self .assertRaises (AssertionError ):
712
874
self .mock .assert_has_awaits (calls , any_order = True )
713
875
714
- asyncio .run (self ._runnable_test ('foo ' ))
876
+ asyncio .run (self ._runnable_test ('bamf ' ))
715
877
with self .assertRaises (AssertionError ):
716
878
self .mock .assert_has_awaits (calls , any_order = True )
717
879
718
- asyncio .run (self ._runnable_test ('NormalFoo ' ))
880
+ asyncio .run (self ._runnable_test ('foo ' ))
719
881
self .mock .assert_has_awaits (calls , any_order = True )
720
882
721
883
asyncio .run (self ._runnable_test ('qux' ))
0 commit comments