Skip to content

Commit 63eb59b

Browse files
committed
Merge branch 'feat/session-grants' into feat/ev-claim
2 parents 3df9f16 + e58cb87 commit 63eb59b

16 files changed

+191
-118
lines changed

supertokens_python/recipe/session/asyncio/__init__.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,14 @@
2424
SessionInformationResult,
2525
SessionClaim,
2626
SessionClaimValidator,
27-
SessionDoesnotExistError,
27+
SessionDoesNotExistError,
2828
ClaimsValidationResult,
2929
JSONObject,
3030
GetClaimValueOkResult,
3131
)
3232
from supertokens_python.recipe.session.recipe import SessionRecipe
3333
from supertokens_python.types import MaybeAwaitable
34-
from supertokens_python.utils import FRAMEWORKS, resolve
34+
from supertokens_python.utils import FRAMEWORKS, resolve, deprecated_warn
3535
from ..utils import get_required_claim_validators
3636
from ...jwt.interfaces import (
3737
CreateJwtOkResult,
@@ -53,12 +53,31 @@ async def create_new_session(
5353
) -> SessionContainer:
5454
if user_context is None:
5555
user_context = {}
56+
if session_data is None:
57+
session_data = {}
58+
if access_token_payload is None:
59+
access_token_payload = {}
60+
61+
claims_added_by_other_recipes = (
62+
SessionRecipe.get_instance().get_claims_added_by_other_recipes()
63+
)
64+
final_access_token_payload = access_token_payload
65+
66+
for claim in claims_added_by_other_recipes:
67+
update = await claim.build(user_id, user_context)
68+
final_access_token_payload = {**final_access_token_payload, **update}
69+
5670
if not hasattr(request, "wrapper_used") or not request.wrapper_used:
5771
request = FRAMEWORKS[
5872
SessionRecipe.get_instance().app_info.framework
5973
].wrap_request(request)
74+
6075
return await SessionRecipe.get_instance().recipe_implementation.create_new_session(
61-
request, user_id, access_token_payload, session_data, user_context=user_context
76+
request,
77+
user_id,
78+
final_access_token_payload,
79+
session_data,
80+
user_context=user_context,
6281
)
6382

6483

@@ -75,7 +94,7 @@ async def validate_claims_for_session_handle(
7594
]
7695
] = None,
7796
user_context: Union[None, Dict[str, Any]] = None,
78-
) -> Union[SessionDoesnotExistError, ClaimsValidationResult]:
97+
) -> Union[SessionDoesNotExistError, ClaimsValidationResult]:
7998
if user_context is None:
8099
user_context = {}
81100

@@ -85,10 +104,10 @@ async def validate_claims_for_session_handle(
85104
)
86105

87106
if session_info is None:
88-
return SessionDoesnotExistError()
107+
return SessionDoesNotExistError()
89108

90109
claim_validators_added_by_other_recipes = (
91-
SessionRecipe.get_claim_validators_added_by_other_recipes()
110+
SessionRecipe.get_instance().get_claim_validators_added_by_other_recipes()
92111
)
93112
global_claim_validators = await resolve(
94113
recipe_impl.get_global_claim_validators(
@@ -133,7 +152,7 @@ async def validate_claims_in_jwt_payload(
133152
recipe_impl = SessionRecipe.get_instance().recipe_implementation
134153

135154
claim_validators_added_by_other_recipes = (
136-
SessionRecipe.get_claim_validators_added_by_other_recipes()
155+
SessionRecipe.get_instance().get_claim_validators_added_by_other_recipes()
137156
)
138157
global_claim_validators = await resolve(
139158
recipe_impl.get_global_claim_validators(
@@ -173,7 +192,7 @@ async def get_claim_value(
173192
session_handle: str,
174193
claim: SessionClaim[_T],
175194
user_context: Union[None, Dict[str, Any]] = None,
176-
) -> Union[SessionDoesnotExistError, GetClaimValueOkResult[_T]]:
195+
) -> Union[SessionDoesNotExistError, GetClaimValueOkResult[_T]]:
177196
if user_context is None:
178197
user_context = {}
179198
return await SessionRecipe.get_instance().recipe_implementation.get_claim_value(
@@ -325,6 +344,11 @@ async def update_access_token_payload(
325344
) -> bool:
326345
if user_context is None:
327346
user_context = {}
347+
348+
deprecated_warn(
349+
"update_access_token_payload is deprecated. Use merge_into_access_token_payload instead"
350+
)
351+
328352
return await SessionRecipe.get_instance().recipe_implementation.update_access_token_payload(
329353
session_handle, new_access_token_payload, user_context
330354
)
@@ -337,7 +361,7 @@ async def merge_into_access_token_payload(
337361
) -> bool:
338362
if user_context is None:
339363
user_context = {}
340-
# TODO:
364+
341365
return await SessionRecipe.get_instance().recipe_implementation.merge_into_access_token_payload(
342366
session_handle, new_access_token_payload, user_context
343367
)

supertokens_python/recipe/session/claim_base_classes/boolean_claim.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,4 @@ def __init__(
4141
],
4242
):
4343
super().__init__(key, fetch_value)
44-
claim = self
45-
self.validators = BooleanClaimValidators(claim)
44+
self.validators = BooleanClaimValidators(claim=self)

supertokens_python/recipe/session/claim_base_classes/primitive_claim.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@
2929

3030

3131
class HasValueSCV(SessionClaimValidator):
32-
def __init__(self, id_: str, claim: SessionClaim[_T], params: Dict[str, Any]):
32+
def __init__(self, id_: str, claim: SessionClaim[_T], val: _T):
3333
super().__init__(id_)
34-
self.claim: SessionClaim[_T] = claim
35-
self.params = params
34+
self.claim: SessionClaim[_T] = claim # Required to fix the type for pyright
35+
self.val = val
3636

3737
def should_refetch(
3838
self,
@@ -46,9 +46,9 @@ async def validate(
4646
payload: JSONObject,
4747
user_context: Dict[str, Any],
4848
):
49-
val = self.params["val"]
49+
val = self.val
5050
claim_val = self.claim.get_value_from_payload(payload, user_context)
51-
is_valid = claim_val == val
51+
is_valid = claim_val == val # type: ignore
5252
if is_valid:
5353
return ClaimValidationResult(is_valid=True)
5454

@@ -63,17 +63,18 @@ async def validate(
6363

6464

6565
class HasFreshValueSCV(SessionClaimValidator):
66-
def __init__(self, id_: str, claim: SessionClaim[_T], params: Dict[str, Any]):
66+
def __init__(self, id_: str, claim: SessionClaim[_T], val: _T, max_age_in_sec: int):
6767
super().__init__(id_)
6868
self.claim: SessionClaim[_T] = claim
69-
self.params = params
69+
self.val = val
70+
self.max_age_in_sec = max_age_in_sec
7071

7172
def should_refetch(
7273
self,
7374
payload: JSONObject,
7475
user_context: Dict[str, Any],
7576
):
76-
max_age_in_sec: int = self.params["max_age_in_sec"]
77+
max_age_in_sec: int = self.max_age_in_sec
7778

7879
# (claim value is None) OR (value has expired)
7980
return (self.claim.get_value_from_payload(payload, user_context) is None) or (
@@ -85,8 +86,8 @@ async def validate(
8586
payload: JSONObject,
8687
user_context: Dict[str, Any],
8788
):
88-
val: str = self.params["val"]
89-
max_age_in_sec: int = self.params["max_age_in_sec"]
89+
val = self.val
90+
max_age_in_sec = self.max_age_in_sec
9091

9192
claim_val = self.claim.get_value_from_payload(payload, user_context)
9293
if claim_val is None:
@@ -113,7 +114,7 @@ async def validate(
113114
},
114115
)
115116

116-
if claim_val != val:
117+
if claim_val != val: # type: ignore
117118
return ClaimValidationResult(
118119
is_valid=False,
119120
reason={
@@ -132,15 +133,16 @@ def __init__(self, claim: SessionClaim[_T]) -> None:
132133
self.claim = claim
133134

134135
def has_value(self, val: _T, id_: Union[str, None] = None) -> SessionClaimValidator:
135-
return HasValueSCV((id_ or self.claim.key), self.claim, {"val": val})
136+
return HasValueSCV((id_ or self.claim.key), self.claim, val=val)
136137

137138
def has_fresh_value(
138139
self, val: _T, max_age_in_sec: int, id_: Union[str, None] = None
139140
) -> SessionClaimValidator:
140141
return HasFreshValueSCV(
141142
(id_ or (self.claim.key + "-fresh-val")),
142143
self.claim,
143-
{"val": val, "max_age_in_sec": max_age_in_sec},
144+
val=val,
145+
max_age_in_sec=max_age_in_sec,
144146
)
145147

146148

supertokens_python/recipe/session/exceptions.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# under the License.
1414
from __future__ import annotations
1515

16-
from typing import Union, Any, List, Dict
16+
from typing import Union, Any, List, Dict, Optional
1717

1818
from supertokens_python.exceptions import SuperTokensError
1919

@@ -56,13 +56,18 @@ class TryRefreshTokenError(SuperTokensSessionError):
5656
class InvalidClaimsError(SuperTokensSessionError):
5757
def __init__(self, msg: str, payload: List[ClaimValidationError]):
5858
super().__init__(msg)
59-
self.payload = [
60-
p.__dict__ for p in payload
61-
] # Must be JSON serializable as it will be used in response
59+
self.payload: List[Dict[str, Any]] = []
60+
for p in payload:
61+
res = (
62+
p.__dict__.copy()
63+
) # Must be JSON serializable as it will be used in response
64+
if p.reason is None:
65+
res.pop("reason")
66+
self.payload.append(res)
6267

6368

6469
class ClaimValidationError:
65-
def __init__(self, id_: str, reason: Dict[str, Any]):
70+
def __init__(self, id_: str, reason: Optional[Dict[str, Any]]):
6671
self.id = id_
6772
self.reason = reason
6873

supertokens_python/recipe/session/interfaces.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __init__(
8282
FetchValueReturnType = Union[_T, None]
8383

8484

85-
class SessionDoesnotExistError:
85+
class SessionDoesNotExistError:
8686
pass
8787

8888

@@ -145,7 +145,7 @@ async def validate_claims_for_session_handle(
145145
session_info: SessionInformationResult,
146146
claim_validators: List[SessionClaimValidator],
147147
user_context: Dict[str, Any],
148-
) -> Union[ClaimsValidationResult, SessionDoesnotExistError]:
148+
) -> Union[ClaimsValidationResult, SessionDoesNotExistError]:
149149
pass
150150

151151
@abstractmethod
@@ -210,8 +210,6 @@ async def update_access_token_payload(
210210
new_access_token_payload: Dict[str, Any],
211211
user_context: Dict[str, Any],
212212
) -> bool:
213-
# TODO: Deprecate this method.
214-
# TODO: need to mark updateAccessTokenPayload as deprecated
215213
"""DEPRECATED: Use merge_into_access_token_payload instead"""
216214

217215
@abstractmethod
@@ -256,7 +254,7 @@ async def get_claim_value(
256254
session_handle: str,
257255
claim: SessionClaim[Any],
258256
user_context: Dict[str, Any],
259-
) -> Union[SessionDoesnotExistError, GetClaimValueOkResult[Any]]:
257+
) -> Union[SessionDoesNotExistError, GetClaimValueOkResult[Any]]:
260258
pass
261259

262260
@abstractmethod
@@ -451,7 +449,9 @@ async def get_claim_value(
451449

452450
@abstractmethod
453451
async def remove_claim(
454-
self, claim: SessionClaim[Any], user_context: Union[Dict[str, Any], None] = None
452+
self,
453+
claim: SessionClaim[Any],
454+
user_context: Union[Dict[str, Any], None] = None,
455455
) -> None:
456456
pass
457457

@@ -473,6 +473,15 @@ def sync_get_time_created(
473473
) -> int:
474474
return sync(self.get_time_created(user_context))
475475

476+
def sync_merge_into_access_token_payload(
477+
self, access_token_payload_update: Dict[str, Any], user_context: Dict[str, Any]
478+
) -> None:
479+
return sync(
480+
self.merge_into_access_token_payload(
481+
access_token_payload_update, user_context
482+
)
483+
)
484+
476485
def sync_update_access_token_payload(
477486
self,
478487
new_access_token_payload: Dict[str, Any],
@@ -590,23 +599,20 @@ def __init__(self, is_valid: bool, reason: Optional[Dict[str, Any]] = None):
590599

591600

592601
class SessionClaimValidator(ABC):
593-
def __init__(self, id_: str):
602+
def __init__(
603+
self,
604+
id_: str,
605+
) -> None:
594606
self.id = id_
595-
self.claim: Optional[
596-
SessionClaim[Any]
597-
] = None # Child class must set this if required.
607+
self.claim: Optional[SessionClaim[Any]] = None
598608

599609
@abstractmethod
600610
async def validate(
601611
self, payload: JSONObject, user_context: Dict[str, Any]
602612
) -> ClaimValidationResult:
603613
pass
604614

605-
def should_refetch( # pylint: disable=no-self-use
615+
def should_refetch(
606616
self, payload: JSONObject, user_context: Dict[str, Any]
607617
) -> MaybeAwaitable[bool]:
608-
# TODO: https://github.com/supertokens/supertokens-python/pull/209#discussion_r932121943
609-
# TODO: This should also be an abstractmethod
610-
# TODO: This should also be async
611-
_, __ = payload, user_context
612-
return False
618+
raise NotImplementedError()

supertokens_python/recipe/session/recipe.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,6 @@ class SessionRecipe(RecipeModule):
6666
recipe_id = "session"
6767
__instance = None
6868

69-
claims_added_by_other_recipes: List[SessionClaim[Any]] = []
70-
claim_validators_added_by_other_recipes: List[SessionClaimValidator] = []
71-
7269
def __init__(
7370
self,
7471
recipe_id: str,
@@ -156,6 +153,9 @@ def __init__(
156153
else self.config.override.apis(api_implementation)
157154
)
158155

156+
self.claims_added_by_other_recipes: List[SessionClaim[Any]] = []
157+
self.claim_validators_added_by_other_recipes: List[SessionClaimValidator] = []
158+
159159
def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool:
160160
return isinstance(err, SuperTokensError) and (
161161
isinstance(err, SuperTokensSessionError)
@@ -302,31 +302,28 @@ def reset():
302302
):
303303
raise_general_exception("calling testing function in non testing env")
304304
SessionRecipe.__instance = None
305-
# FIXME: Discovered its requirement while running tests. Confirm if this is correct:
306-
SessionRecipe.claims_added_by_other_recipes = []
307-
SessionRecipe.claim_validators_added_by_other_recipes = []
308305

309-
@staticmethod
310-
def add_claim_from_other_recipe(claim: SessionClaim[Any]):
306+
def add_claim_from_other_recipe(self, claim: SessionClaim[Any]):
311307
# We are throwing here (and not in addClaimValidatorFromOtherRecipe) because if multiple
312308
# claims are added with the same key they will overwrite each other. Validators will all run
313309
# and work as expected even if they are added multiple times.
314-
if claim.key in [c.key for c in SessionRecipe.claims_added_by_other_recipes]:
310+
if claim.key in [c.key for c in self.claims_added_by_other_recipes]:
315311
raise Exception("Claim added by multiple recipes")
316312

317-
SessionRecipe.claims_added_by_other_recipes.append(claim)
313+
self.claims_added_by_other_recipes.append(claim)
318314

319-
@staticmethod
320-
def get_claims_added_by_other_recipes() -> List[SessionClaim[Any]]:
321-
return SessionRecipe.claims_added_by_other_recipes
315+
def get_claims_added_by_other_recipes(self) -> List[SessionClaim[Any]]:
316+
return self.claims_added_by_other_recipes
322317

323-
@staticmethod
324-
def add_claim_validator_from_other_recipe(claim_validator: SessionClaimValidator):
325-
SessionRecipe.claim_validators_added_by_other_recipes.append(claim_validator)
318+
def add_claim_validator_from_other_recipe(
319+
self, claim_validator: SessionClaimValidator
320+
):
321+
self.claim_validators_added_by_other_recipes.append(claim_validator)
326322

327-
@staticmethod
328-
def get_claim_validators_added_by_other_recipes() -> List[SessionClaimValidator]:
329-
return SessionRecipe.claim_validators_added_by_other_recipes
323+
def get_claim_validators_added_by_other_recipes(
324+
self,
325+
) -> List[SessionClaimValidator]:
326+
return self.claim_validators_added_by_other_recipes
330327

331328
async def verify_session(
332329
self,

0 commit comments

Comments
 (0)