Skip to content

Commit 0ac965c

Browse files
committed
test invalid claims response with and without reason
1 parent dc9c3be commit 0ac965c

File tree

3 files changed

+37
-16
lines changed

3 files changed

+37
-16
lines changed

supertokens_python/recipe/session/exceptions.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# under the License.
1414
from __future__ import annotations
1515

16-
from typing import Union, Any, List, Dict
16+
from typing import Union, Any, List, Dict, Optional
1717

1818
from supertokens_python.exceptions import SuperTokensError
1919

@@ -56,13 +56,18 @@ class TryRefreshTokenError(SuperTokensSessionError):
5656
class InvalidClaimsError(SuperTokensSessionError):
5757
def __init__(self, msg: str, payload: List[ClaimValidationError]):
5858
super().__init__(msg)
59-
self.payload = [
60-
p.__dict__ for p in payload
61-
] # Must be JSON serializable as it will be used in response
59+
self.payload: List[Dict[str, Any]] = []
60+
for p in payload:
61+
res = (
62+
p.__dict__.copy()
63+
) # Must be JSON serializable as it will be used in response
64+
if p.reason is None:
65+
res.pop("reason")
66+
self.payload.append(res)
6267

6368

6469
class ClaimValidationError:
65-
def __init__(self, id_: str, reason: Dict[str, Any]):
70+
def __init__(self, id_: str, reason: Optional[Dict[str, Any]]):
6671
self.id = id_
6772
self.reason = reason
6873

supertokens_python/recipe/session/utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -565,8 +565,6 @@ async def validate_claims_in_payload(
565565
json.dumps(claim_validation_res.__dict__),
566566
)
567567
if not claim_validation_res.is_valid:
568-
assert claim_validation_res.reason is not None
569-
570568
validation_errors.append(
571569
ClaimValidationError(validator.id, claim_validation_res.reason)
572570
)

tests/sessions/claims/test_verify_session.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Any, Dict, Union
1+
from typing import List, Any, Dict, Union, Optional
22
from unittest.mock import patch
33

44
from fastapi import FastAPI, Depends
@@ -99,13 +99,14 @@ def should_refetch(self, payload: JSONObject, user_context: Dict[str, Any]):
9999

100100

101101
class AlwaysInvalidValidator(SessionClaimValidator):
102-
def __init__(self):
102+
def __init__(self, reason: Optional[Dict[str, Any]]):
103103
super().__init__("always-invalid-validator", TrueClaim)
104+
self.reason = reason
104105

105106
async def validate(
106107
self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None
107108
) -> ClaimValidationResult:
108-
return ClaimValidationResult(is_valid=False, reason={"message": "foo"})
109+
return ClaimValidationResult(is_valid=False, reason=self.reason)
109110

110111
def should_refetch(self, payload: JSONObject, user_context: Dict[str, Any]) -> bool:
111112
return True
@@ -290,10 +291,28 @@ async def test_should_allow_with_custom_validator_returning_true(
290291
assert "-" in res.json()["handle"]
291292

292293

293-
async def test_should_reject_with_custom_validator_returning_false(
294+
async def test_should_reject_with_custom_validator_returning_false_without_reason(
294295
fastapi_client: TestClient,
295296
):
296-
custom_validator = AlwaysInvalidValidator()
297+
custom_validator = AlwaysInvalidValidator(reason=None)
298+
299+
st_init_args = st_init_generator_with_claim_validator(custom_validator)
300+
init(**st_init_args) # type: ignore
301+
start_st()
302+
303+
create_session(fastapi_client)
304+
response = fastapi_client.get("/default-claims")
305+
assert response.status_code == 403
306+
assert response.json() == {
307+
"message": "invalid claim",
308+
"claimValidationErrors": [{"id": "always-invalid-validator"}],
309+
}
310+
311+
312+
async def test_should_reject_with_custom_validator_returning_false_with_reason(
313+
fastapi_client: TestClient,
314+
):
315+
custom_validator = AlwaysInvalidValidator(reason={"message": "foo"})
297316

298317
st_init_args = st_init_generator_with_claim_validator(custom_validator)
299318
init(**st_init_args) # type: ignore
@@ -310,9 +329,6 @@ async def test_should_reject_with_custom_validator_returning_false(
310329
}
311330

312331

313-
# should reject with validator returning false with reason (Leaving this. It's exactly same as prev.)
314-
315-
316332
async def test_should_reject_if_assert_claims_returns_an_error(
317333
fastapi_client: TestClient,
318334
):
@@ -440,7 +456,9 @@ async def test_should_reject_with_custom_claim_returning_false(
440456
):
441457
# This gets overriden by override_global_claim_validators passed to verify_session()
442458
# in "/refetched-claim-isvalid-false" api
443-
cv = AlwaysInvalidValidator()
459+
cv = AlwaysInvalidValidator(
460+
reason={"message": "does not matter because of override"}
461+
)
444462
st_init_args = st_init_generator_with_claim_validator(cv)
445463
init(**st_init_args) # type: ignore
446464
start_st()

0 commit comments

Comments
 (0)