Skip to content

Commit c26242e

Browse files
committed
feat: Changes suggested in feedback
1 parent 54e2736 commit c26242e

File tree

14 files changed

+112
-73
lines changed

14 files changed

+112
-73
lines changed

supertokens_python/recipe/session/asyncio/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
SessionInformationResult,
2323
SessionClaim,
2424
SessionClaimValidator,
25-
SessionDoesnotExistError,
25+
SessionDoesNotExistError,
2626
ClaimsValidationResult,
2727
JSONObject,
2828
GetClaimValueOkResult,
@@ -71,7 +71,7 @@ async def validate_claims_for_session_handle(
7171
]
7272
] = None,
7373
user_context: Union[None, Dict[str, Any]] = None,
74-
) -> Union[SessionDoesnotExistError, ClaimsValidationResult]:
74+
) -> Union[SessionDoesNotExistError, ClaimsValidationResult]:
7575
if user_context is None:
7676
user_context = {}
7777

@@ -81,7 +81,7 @@ async def validate_claims_for_session_handle(
8181
)
8282

8383
if session_info is None:
84-
return SessionDoesnotExistError()
84+
return SessionDoesNotExistError()
8585

8686
claim_validators_added_by_other_recipes = (
8787
SessionRecipe.get_claim_validators_added_by_other_recipes()
@@ -169,7 +169,7 @@ async def get_claim_value(
169169
session_handle: str,
170170
claim: SessionClaim[_T],
171171
user_context: Union[None, Dict[str, Any]] = None,
172-
) -> Union[SessionDoesnotExistError, GetClaimValueOkResult[_T]]:
172+
) -> Union[SessionDoesNotExistError, GetClaimValueOkResult[_T]]:
173173
if user_context is None:
174174
user_context = {}
175175
return await SessionRecipe.get_instance().recipe_implementation.get_claim_value(
@@ -333,7 +333,7 @@ async def merge_into_access_token_payload(
333333
) -> bool:
334334
if user_context is None:
335335
user_context = {}
336-
# TODO:
336+
337337
return await SessionRecipe.get_instance().recipe_implementation.merge_into_access_token_payload(
338338
session_handle, new_access_token_payload, user_context
339339
)

supertokens_python/recipe/session/claim_base_classes/boolean_claim.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,4 @@ def __init__(
4141
],
4242
):
4343
super().__init__(key, fetch_value)
44-
claim = self
45-
self.validators = BooleanClaimValidators(claim)
44+
self.validators = BooleanClaimValidators(claim=self)

supertokens_python/recipe/session/claim_base_classes/primitive_claim.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,9 @@
2929

3030

3131
class HasValueSCV(SessionClaimValidator):
32-
def __init__(self, id_: str, claim: SessionClaim[_T], params: Dict[str, Any]):
33-
super().__init__(id_)
34-
self.claim: SessionClaim[_T] = claim
35-
self.params = params
32+
def __init__(self, id_: str, claim: SessionClaim[_T], val: _T):
33+
super().__init__(id_, claim)
34+
self.val = val
3635

3736
def should_refetch(
3837
self,
@@ -46,9 +45,9 @@ async def validate(
4645
payload: JSONObject,
4746
user_context: Dict[str, Any],
4847
):
49-
val = self.params["val"]
48+
val = self.val
5049
claim_val = self.claim.get_value_from_payload(payload, user_context)
51-
is_valid = claim_val == val
50+
is_valid = claim_val == val # type: ignore
5251
if is_valid:
5352
return ClaimValidationResult(is_valid=True)
5453

@@ -63,17 +62,17 @@ async def validate(
6362

6463

6564
class HasFreshValueSCV(SessionClaimValidator):
66-
def __init__(self, id_: str, claim: SessionClaim[_T], params: Dict[str, Any]):
67-
super().__init__(id_)
68-
self.claim: SessionClaim[_T] = claim
69-
self.params = params
65+
def __init__(self, id_: str, claim: SessionClaim[_T], val: _T, max_age_in_sec: int):
66+
super().__init__(id_, claim)
67+
self.val = val
68+
self.max_age_in_sec = max_age_in_sec
7069

7170
def should_refetch(
7271
self,
7372
payload: JSONObject,
7473
user_context: Dict[str, Any],
7574
):
76-
max_age_in_sec: int = self.params["max_age_in_sec"]
75+
max_age_in_sec: int = self.max_age_in_sec
7776

7877
# (claim value is None) OR (value has expired)
7978
return (self.claim.get_value_from_payload(payload, user_context) is None) or (
@@ -85,8 +84,8 @@ async def validate(
8584
payload: JSONObject,
8685
user_context: Dict[str, Any],
8786
):
88-
val: str = self.params["val"]
89-
max_age_in_sec: int = self.params["max_age_in_sec"]
87+
val = self.val
88+
max_age_in_sec = self.max_age_in_sec
9089

9190
claim_val = self.claim.get_value_from_payload(payload, user_context)
9291
if claim_val is None:
@@ -113,7 +112,7 @@ async def validate(
113112
},
114113
)
115114

116-
if claim_val != val:
115+
if claim_val != val: # type: ignore
117116
return ClaimValidationResult(
118117
is_valid=False,
119118
reason={
@@ -131,15 +130,16 @@ def __init__(self, claim: SessionClaim[_T]) -> None:
131130
self.claim = claim
132131

133132
def has_value(self, val: _T, id_: Union[str, None] = None) -> SessionClaimValidator:
134-
return HasValueSCV((id_ or self.claim.key), self.claim, {"val": val})
133+
return HasValueSCV((id_ or self.claim.key), self.claim, val=val)
135134

136135
def has_fresh_value(
137136
self, val: _T, max_age_in_sec: int, id_: Union[str, None] = None
138137
) -> SessionClaimValidator:
139138
return HasFreshValueSCV(
140139
(id_ or (self.claim.key + "-fresh-val")),
141140
self.claim,
142-
{"val": val, "max_age_in_sec": max_age_in_sec},
141+
val=val,
142+
max_age_in_sec=max_age_in_sec,
143143
)
144144

145145

supertokens_python/recipe/session/interfaces.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __init__(
8282
FetchValueReturnType = Union[_T, None]
8383

8484

85-
class SessionDoesnotExistError:
85+
class SessionDoesNotExistError:
8686
pass
8787

8888

@@ -145,7 +145,7 @@ async def validate_claims_for_session_handle(
145145
session_info: SessionInformationResult,
146146
claim_validators: List[SessionClaimValidator],
147147
user_context: Dict[str, Any],
148-
) -> Union[ClaimsValidationResult, SessionDoesnotExistError]:
148+
) -> Union[ClaimsValidationResult, SessionDoesNotExistError]:
149149
pass
150150

151151
@abstractmethod
@@ -210,7 +210,6 @@ async def update_access_token_payload(
210210
new_access_token_payload: Dict[str, Any],
211211
user_context: Dict[str, Any],
212212
) -> bool:
213-
# TODO: Deprecate this method.
214213
"""DEPRECATED: Use merge_into_access_token_payload instead"""
215214

216215
@abstractmethod
@@ -255,7 +254,7 @@ async def get_claim_value(
255254
session_handle: str,
256255
claim: SessionClaim[Any],
257256
user_context: Dict[str, Any],
258-
) -> Union[SessionDoesnotExistError, GetClaimValueOkResult[Any]]:
257+
) -> Union[SessionDoesNotExistError, GetClaimValueOkResult[Any]]:
259258
pass
260259

261260
@abstractmethod
@@ -450,7 +449,9 @@ async def get_claim_value(
450449

451450
@abstractmethod
452451
async def remove_claim(
453-
self, claim: SessionClaim[Any], user_context: Union[Dict[str, Any], None] = None
452+
self,
453+
claim: SessionClaim[_T], # pyright: ignore[reportInvalidTypeVarUse]
454+
user_context: Union[Dict[str, Any], None] = None,
454455
) -> None:
455456
pass
456457

@@ -472,6 +473,15 @@ def sync_get_time_created(
472473
) -> int:
473474
return sync(self.get_time_created(user_context))
474475

476+
def sync_merge_into_access_token_payload(
477+
self, access_token_payload_update: Dict[str, Any], user_context: Dict[str, Any]
478+
) -> None:
479+
return sync(
480+
self.merge_into_access_token_payload(
481+
access_token_payload_update, user_context
482+
)
483+
)
484+
475485
def sync_update_access_token_payload(
476486
self,
477487
new_access_token_payload: Dict[str, Any],
@@ -589,20 +599,23 @@ def __init__(self, is_valid: bool, reason: Optional[Dict[str, Any]] = None):
589599

590600

591601
class SessionClaimValidator(ABC):
592-
def __init__(self, id_: str):
602+
def __init__(
603+
self,
604+
id_: str,
605+
claim: SessionClaim[_T], # pyright: ignore[reportInvalidTypeVarUse]
606+
) -> None:
593607
self.id = id_
594-
self.claim: Optional[
595-
SessionClaim[Any]
596-
] = None # Child class must set this if required.
608+
self.claim = claim
597609

598610
@abstractmethod
599611
async def validate(
600612
self, payload: JSONObject, user_context: Dict[str, Any]
601613
) -> ClaimValidationResult:
602614
pass
603615

604-
def should_refetch( # pylint: disable=no-self-use
616+
@abstractmethod
617+
def should_refetch(
605618
self, payload: JSONObject, user_context: Dict[str, Any]
606619
) -> MaybeAwaitable[bool]:
607-
_, __ = payload, user_context
608-
return False
620+
# TODO: Confirm that MaybeAwaitable actually makes the function async
621+
pass

supertokens_python/recipe/session/recipe_implementation.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import json
1717
from typing import TYPE_CHECKING, Any, Dict, Optional
18-
1918
from supertokens_python.framework.request import BaseRequest
2019
from supertokens_python.logger import log_debug_message
2120
from supertokens_python.normalised_url_path import NormalisedURLPath
@@ -26,6 +25,7 @@
2625
get_timestamp_ms,
2726
normalise_http_method,
2827
resolve,
28+
deprecated_warn,
2929
)
3030
from . import session_functions
3131
from .cookie_and_header import (
@@ -45,7 +45,7 @@
4545
SessionInformationResult,
4646
SessionObj,
4747
ClaimsValidationResult,
48-
SessionDoesnotExistError,
48+
SessionDoesNotExistError,
4949
JSONObject,
5050
GetClaimValueOkResult,
5151
)
@@ -193,7 +193,7 @@ async def validate_claims_for_session_handle(
193193
session_info: SessionInformationResult,
194194
claim_validators: List[SessionClaimValidator],
195195
user_context: Dict[str, Any],
196-
) -> Union[ClaimsValidationResult, SessionDoesnotExistError]:
196+
) -> Union[ClaimsValidationResult, SessionDoesNotExistError]:
197197
original_session_claim_payload_json = json.dumps(
198198
session_info.access_token_payload
199199
)
@@ -212,7 +212,7 @@ async def validate_claims_for_session_handle(
212212
user_context,
213213
)
214214
if res is False:
215-
return SessionDoesnotExistError()
215+
return SessionDoesNotExistError()
216216

217217
invalid_claims = await validate_claims_in_payload(
218218
claim_validators,
@@ -402,7 +402,10 @@ async def update_access_token_payload(
402402
new_access_token_payload: Dict[str, Any],
403403
user_context: Dict[str, Any],
404404
) -> bool:
405-
"""DEPRECATED: Use merge_into_access_token_payload instead"""
405+
deprecated_warn(
406+
"update_access_token_payload is deprecated. Use merge_into_access_token_payload instead"
407+
)
408+
406409
return await session_functions.update_access_token_payload(
407410
self, session_handle, new_access_token_payload
408411
)
@@ -469,10 +472,10 @@ async def get_claim_value(
469472
session_handle: str,
470473
claim: SessionClaim[Any],
471474
user_context: Dict[str, Any],
472-
) -> Union[SessionDoesnotExistError, GetClaimValueOkResult[Any]]:
475+
) -> Union[SessionDoesNotExistError, GetClaimValueOkResult[Any]]:
473476
session_info = await self.get_session_information(session_handle, user_context)
474477
if session_info is None:
475-
return SessionDoesnotExistError()
478+
return SessionDoesNotExistError()
476479

477480
return GetClaimValueOkResult(
478481
value=claim.get_value_from_payload(

supertokens_python/recipe/session/session_class.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,12 @@ async def update_session_data(
6262

6363
async def update_access_token_payload(
6464
self,
65-
new_access_token_payload: Union[Dict[str, Any], None],
66-
user_context: Dict[str, Any],
65+
new_access_token_payload: Dict[str, Any],
66+
user_context: Union[Dict[str, Any], None] = None,
6767
) -> None:
68+
if user_context is None:
69+
user_context = {}
70+
6871
response = await self.recipe_implementation.regenerate_access_token(
6972
self.access_token, new_access_token_payload, user_context
7073
)

supertokens_python/recipe/session/syncio/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
SessionClaim,
3434
JSONObject,
3535
ClaimsValidationResult,
36-
SessionDoesnotExistError,
36+
SessionDoesNotExistError,
3737
GetClaimValueOkResult,
3838
)
3939

@@ -267,7 +267,7 @@ def get_claim_value(
267267
session_handle: str,
268268
claim: SessionClaim[_T],
269269
user_context: Union[None, Dict[str, Any]] = None,
270-
) -> Union[SessionDoesnotExistError, GetClaimValueOkResult[_T]]:
270+
) -> Union[SessionDoesNotExistError, GetClaimValueOkResult[_T]]:
271271
from supertokens_python.recipe.session.asyncio import (
272272
get_claim_value as async_get_claim_value,
273273
)
@@ -296,7 +296,7 @@ def validate_claims_for_session_handle(
296296
]
297297
] = None,
298298
user_context: Union[None, Dict[str, Any]] = None,
299-
) -> Union[SessionDoesnotExistError, ClaimsValidationResult]:
299+
) -> Union[SessionDoesNotExistError, ClaimsValidationResult]:
300300
from supertokens_python.recipe.session.asyncio import (
301301
validate_claims_for_session_handle as async_validate_claims_for_session_handle,
302302
)

supertokens_python/recipe/session/utils.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,17 @@ def validate_and_normalise_user_input(
418418
session_expired_status_code = (
419419
session_expired_status_code if session_expired_status_code is not None else 401
420420
)
421+
422+
invalid_claim_status_code = (
423+
invalid_claim_status_code if invalid_claim_status_code is not None else 403
424+
)
425+
426+
if session_expired_status_code == invalid_claim_status_code:
427+
raise Exception(
428+
"session_expired_status_code and invalid_claim_status_code cannot be the same "
429+
f"({invalid_claim_status_code})"
430+
)
431+
421432
if anti_csrf is None:
422433
anti_csrf = "VIA_CUSTOM_HEADER" if cookie_same_site == "none" else "NONE"
423434

@@ -463,8 +474,7 @@ def validate_and_normalise_user_input(
463474
app_info.framework,
464475
app_info.mode,
465476
jwt,
466-
invalid_claim_status_code if (invalid_claim_status_code is not None) else 403
467-
# TODO: above line was marked as TODO in review, not sure why.
477+
invalid_claim_status_code,
468478
)
469479

470480

@@ -554,10 +564,9 @@ async def validate_claims_in_payload(
554564
validator.id,
555565
json.dumps(claim_validation_res.__dict__),
556566
)
557-
if (
558-
not claim_validation_res.is_valid
559-
and claim_validation_res.reason is not None
560-
):
567+
if not claim_validation_res.is_valid:
568+
assert claim_validation_res.reason is not None
569+
561570
validation_errors.append(
562571
ClaimValidationError(validator.id, claim_validation_res.reason)
563572
)

0 commit comments

Comments
 (0)