1
+ import json
1
2
from typing import Any , Dict , TypeVar , Union
2
3
from unittest .mock import patch
3
4
@@ -49,16 +50,19 @@ async def test_should_call_validate_with_the_same_payload_object():
49
50
class DummyClaimValidator (SessionClaimValidator ):
50
51
def __init__ (self , claim : SessionClaim [Any ]):
51
52
super ().__init__ ("claim_validator_id" , claim )
52
- self .validate_call_count = 0
53
+ self .validate_calls : Dict [ str , int ] = {}
53
54
54
55
async def validate (
55
56
self , payload : JSONObject , user_context : Union [Dict [str , Any ], None ] = None
56
57
):
57
- self .validate_call_count += 1
58
+ payload_json = json .dumps (payload )
59
+ self .validate_calls [payload_json ] = (
60
+ self .validate_calls .get (payload_json , 0 ) + 1
61
+ )
58
62
return ClaimValidationResult (is_valid = True )
59
63
60
64
def should_refetch (self , payload : JSONObject , user_context : Dict [str , Any ]):
61
- return True
65
+ return False
62
66
63
67
dummy_claim = PrimitiveClaim ("st-claim" , lambda _ , __ : "Hello world" )
64
68
@@ -72,6 +76,6 @@ def should_refetch(self, payload: JSONObject, user_context: Dict[str, Any]):
72
76
wraps = session .update_access_token_payload ,
73
77
) as mock :
74
78
await session .assert_claims ([dummy_claim .validators .dummy_claim_validator ]) # type: ignore
75
- mock .assert_not_called ()
76
79
77
- assert dummy_claim_validator .validate_call_count == 1
80
+ assert dummy_claim_validator .validate_calls == {json .dumps (payload ): 1 }
81
+ mock .assert_not_called ()
0 commit comments