Skip to content

Commit 349596a

Browse files
committed
feat: Add more functions related to session claims
1 parent abba205 commit 349596a

File tree

6 files changed

+306
-10
lines changed

6 files changed

+306
-10
lines changed

supertokens_python/recipe/session/asyncio/__init__.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
14-
from typing import Any, Dict, List, Union
14+
from typing import Any, Dict, List, Union, TypeVar
1515

1616
from supertokens_python.recipe.openid.interfaces import (
1717
GetOpenIdDiscoveryConfigurationResult,
@@ -20,16 +20,18 @@
2020
RegenerateAccessTokenOkResult,
2121
SessionContainer,
2222
SessionInformationResult,
23+
SessionClaim,
2324
)
2425
from supertokens_python.recipe.session.recipe import SessionRecipe
2526
from supertokens_python.utils import FRAMEWORKS
26-
2727
from ...jwt.interfaces import (
2828
CreateJwtOkResult,
2929
CreateJwtResultUnsupportedAlgorithm,
3030
GetJWKSResult,
3131
)
3232

33+
_T = TypeVar("_T")
34+
3335

3436
async def create_new_session(
3537
request: Any,
@@ -49,6 +51,43 @@ async def create_new_session(
4951
)
5052

5153

54+
async def get_claim_value(
55+
session_handle: str,
56+
claim: SessionClaim[_T],
57+
user_context: Union[None, Dict[str, Any]] = None,
58+
) -> Union[_T, None]:
59+
if user_context is None:
60+
user_context = {}
61+
return await SessionRecipe.get_instance().recipe_implementation.get_claim_value(
62+
session_handle, claim, user_context
63+
)
64+
65+
66+
async def set_claim_value(
67+
session_handle: str,
68+
claim: SessionClaim[_T],
69+
value: _T,
70+
user_context: Union[None, Dict[str, Any]] = None,
71+
) -> bool:
72+
if user_context is None:
73+
user_context = {}
74+
return await SessionRecipe.get_instance().recipe_implementation.set_claim_value(
75+
session_handle, claim, value, user_context
76+
)
77+
78+
79+
async def remove_claim(
80+
session_handle: str,
81+
claim: SessionClaim[Any],
82+
user_context: Union[None, Dict[str, Any]] = None,
83+
) -> bool:
84+
if user_context is None:
85+
user_context = {}
86+
return await SessionRecipe.get_instance().recipe_implementation.remove_claim(
87+
session_handle, claim, user_context
88+
)
89+
90+
5291
async def get_session(
5392
request: Any,
5493
anti_csrf_check: Union[bool, None] = None,

supertokens_python/recipe/session/interfaces.py

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from supertokens_python.async_to_sync_wrapper import sync
2121
from supertokens_python.types import APIResponse, GeneralErrorResponse
2222

23+
from ...utils import Promise
2324
from .utils import SessionConfig
2425

2526
if TYPE_CHECKING:
@@ -79,6 +80,15 @@ async def create_new_session(
7980
) -> SessionContainer:
8081
pass
8182

83+
@abstractmethod
84+
async def get_global_claim_validators(
85+
self,
86+
user_id: str,
87+
claim_validators_added_by_other_recipes: List[SessionClaimValidator],
88+
user_context: Dict[str, Any],
89+
) -> Union[SessionClaimValidator, Promise[SessionClaimValidator]]:
90+
pass
91+
8292
@abstractmethod
8393
async def get_session(
8494
self,
@@ -151,6 +161,43 @@ async def get_access_token_lifetime_ms(self, user_context: Dict[str, Any]) -> in
151161
async def get_refresh_token_lifetime_ms(self, user_context: Dict[str, Any]) -> int:
152162
pass
153163

164+
@abstractmethod
165+
async def fetch_and_set_claim(
166+
self,
167+
session_handle: str,
168+
claim: SessionClaim[Any],
169+
user_context: Dict[str, Any],
170+
) -> bool:
171+
pass
172+
173+
@abstractmethod
174+
async def set_claim_value(
175+
self,
176+
session_handle: str,
177+
claim: SessionClaim[_T],
178+
value: _T,
179+
user_context: Dict[str, Any],
180+
) -> bool:
181+
pass
182+
183+
@abstractmethod
184+
async def get_claim_value(
185+
self,
186+
session_handle: str,
187+
claim: SessionClaim[Any],
188+
user_context: Dict[str, Any],
189+
):
190+
pass
191+
192+
@abstractmethod
193+
async def remove_claim(
194+
self,
195+
session_handle: str,
196+
claim: SessionClaim[Any],
197+
user_context: Dict[str, Any],
198+
) -> bool:
199+
pass
200+
154201
@abstractmethod
155202
async def regenerate_access_token(
156203
self,
@@ -217,7 +264,7 @@ async def verify_session(
217264
pass
218265

219266

220-
class SessionContainer(ABC):
267+
class SessionContainer(ABC): # pylint: disable=too-many-public-methods
221268
def __init__(
222269
self,
223270
recipe_implementation: RecipeInterface,
@@ -260,6 +307,12 @@ async def update_access_token_payload(
260307
self,
261308
new_access_token_payload: Dict[str, Any],
262309
user_context: Union[Dict[str, Any], None] = None,
310+
) -> None:
311+
"""DEPRECATED: Use merge_into_access_token_payload instead"""
312+
313+
@abstractmethod
314+
async def merge_into_access_token_payload(
315+
self, access_token_payload_update: Dict[str, Any], user_context: Any
263316
) -> None:
264317
pass
265318

@@ -291,6 +344,41 @@ async def get_time_created(
291344
async def get_expiry(self, user_context: Union[Dict[str, Any], None] = None) -> int:
292345
pass
293346

347+
@abstractmethod
348+
async def assert_claims(
349+
self,
350+
claim_validators: List[SessionClaimValidator],
351+
user_context: Union[Dict[str, Any], None] = None,
352+
) -> None:
353+
pass
354+
355+
@abstractmethod
356+
async def fetch_and_set_claim(
357+
self, claim: SessionClaim[Any], user_context: Union[Dict[str, Any], None] = None
358+
) -> None:
359+
pass
360+
361+
@abstractmethod
362+
async def set_claim_value(
363+
self,
364+
claim: SessionClaim[_T],
365+
value: _T,
366+
user_context: Union[Dict[str, Any], None] = None,
367+
) -> None:
368+
pass
369+
370+
@abstractmethod
371+
async def get_claim_value(
372+
self, claim: SessionClaim[_T], user_context: Union[Dict[str, Any], None] = None
373+
) -> Union[_T, None]:
374+
pass
375+
376+
@abstractmethod
377+
async def remove_claim(
378+
self, claim: SessionClaim[Any], user_context: Union[Dict[str, Any], None] = None
379+
) -> None:
380+
pass
381+
294382
def sync_get_expiry(self, user_context: Union[Dict[str, Any], None] = None) -> int:
295383
return sync(self.get_expiry(user_context))
296384

@@ -334,7 +422,6 @@ def __getitem__(self, item: str):
334422
_T = TypeVar("_T")
335423
JSONObject = Dict[str, Any]
336424

337-
338425
JSONPrimitive = Union[str, int, bool, None, Dict[str, Any]]
339426

340427
FetchValueReturnType = Union[_T, None]

supertokens_python/recipe/session/recipe.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,13 @@
3939

4040
from .api.implementation import APIImplementation
4141
from .constants import SESSION_REFRESH, SIGNOUT
42-
from .interfaces import APIInterface, APIOptions, RecipeInterface
42+
from .interfaces import (
43+
APIInterface,
44+
APIOptions,
45+
RecipeInterface,
46+
SessionClaim,
47+
SessionClaimValidator,
48+
)
4349
from .recipe_implementation import RecipeImplementation
4450
from .utils import (
4551
InputErrorHandlers,
@@ -138,6 +144,9 @@ def __init__(
138144
else self.config.override.apis(api_implementation)
139145
)
140146

147+
self.claims_added_by_other_recipes: List[SessionClaim[Any]] = []
148+
self.claim_validators_added_by_other_recipes: List[SessionClaimValidator] = []
149+
141150
def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool:
142151
return isinstance(err, SuperTokensError) and (
143152
isinstance(err, SuperTokensSessionError)
@@ -278,6 +287,22 @@ def reset():
278287
raise_general_exception("calling testing function in non testing env")
279288
SessionRecipe.__instance = None
280289

290+
def add_claim_from_other_recipe(self, claim: SessionClaim[Any]):
291+
self.claims_added_by_other_recipes.append(claim)
292+
293+
def get_claims_added_by_other_recipes(self) -> List[SessionClaim[Any]]:
294+
return self.claims_added_by_other_recipes
295+
296+
def add_claim_validator_from_other_recipe(
297+
self, claim_validator: SessionClaimValidator
298+
):
299+
self.claim_validators_added_by_other_recipes.append(claim_validator)
300+
301+
def get_claim_validators_added_by_other_recipes(
302+
self,
303+
) -> List[SessionClaimValidator]:
304+
return self.claim_validators_added_by_other_recipes
305+
281306
async def verify_session(
282307
self,
283308
request: BaseRequest,

supertokens_python/recipe/session/recipe_implementation.py

Lines changed: 92 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313
# under the License.
1414
from __future__ import annotations
1515

16-
from typing import TYPE_CHECKING, Any, Dict
16+
from typing import TYPE_CHECKING, Any, Dict, TypeVar
1717

1818
from supertokens_python.framework.request import BaseRequest
1919
from supertokens_python.logger import log_debug_message
2020
from supertokens_python.normalised_url_path import NormalisedURLPath
2121
from supertokens_python.process_state import AllowedProcessStates, ProcessState
2222
from supertokens_python.utils import (
23+
Promise,
2324
execute_async,
2425
frontend_has_interceptor,
2526
get_timestamp_ms,
@@ -39,6 +40,8 @@
3940
AccessTokenObj,
4041
RecipeInterface,
4142
RegenerateAccessTokenOkResult,
43+
SessionClaim,
44+
SessionClaimValidator,
4245
SessionInformationResult,
4346
SessionObj,
4447
)
@@ -53,6 +56,8 @@
5356

5457
from .interfaces import SessionContainer
5558

59+
_T = TypeVar("_T")
60+
5661

5762
class HandshakeInfo:
5863
def __init__(self, info: Dict[str, Any]):
@@ -74,7 +79,7 @@ def get_jwt_signing_public_key_list(self) -> List[Dict[str, Any]]:
7479
]
7580

7681

77-
class RecipeImplementation(RecipeInterface):
82+
class RecipeImplementation(RecipeInterface): # pylint: disable=too-many-public-methods
7883
def __init__(self, querier: Querier, config: SessionConfig):
7984
super().__init__()
8085
self.querier = querier
@@ -338,6 +343,91 @@ async def get_access_token_lifetime_ms(self, user_context: Dict[str, Any]) -> in
338343
async def get_refresh_token_lifetime_ms(self, user_context: Dict[str, Any]) -> int:
339344
return (await self.get_handshake_info()).refresh_token_validity
340345

346+
async def merge_into_access_token_payload(
347+
self,
348+
session_handle: str,
349+
access_token_payload_update: Dict[str, Any],
350+
user_context: Dict[str, Any],
351+
) -> bool:
352+
session_info = await self.get_session_information(session_handle, user_context)
353+
if session_info is None:
354+
return False
355+
356+
new_access_token_payload = {
357+
**session_info.access_token_payload,
358+
**access_token_payload_update,
359+
}
360+
for k in access_token_payload_update.keys():
361+
if new_access_token_payload[k] is None:
362+
del new_access_token_payload[k]
363+
364+
return await self.update_access_token_payload(
365+
session_handle, new_access_token_payload, user_context
366+
)
367+
368+
async def fetch_and_set_claim(
369+
self,
370+
session_handle: str,
371+
claim: SessionClaim[Any],
372+
user_context: Dict[str, Any],
373+
) -> bool:
374+
session_info = await self.get_session_information(session_handle, user_context)
375+
if session_info is None:
376+
return False
377+
378+
access_token_payload_update = await claim.build(
379+
session_info.user_id, user_context
380+
)
381+
return await self.merge_into_access_token_payload(
382+
session_handle, access_token_payload_update, user_context
383+
)
384+
385+
async def set_claim_value(
386+
self,
387+
session_handle: str,
388+
claim: SessionClaim[Any],
389+
value: Any,
390+
user_context: Dict[str, Any],
391+
):
392+
access_token_payload_update = claim.add_to_payload_({}, value, user_context)
393+
return await self.merge_into_access_token_payload(
394+
session_handle, access_token_payload_update, user_context
395+
)
396+
397+
async def get_claim_value(
398+
self,
399+
session_handle: str,
400+
claim: SessionClaim[Any],
401+
user_context: Dict[str, Any],
402+
):
403+
session_info = await self.get_session_information(session_handle, user_context)
404+
if session_info is None:
405+
raise Exception("Session does not exist")
406+
407+
return claim.get_value_from_payload(
408+
session_info.access_token_payload, user_context
409+
)
410+
411+
async def get_global_claim_validators(
412+
self,
413+
user_id: str,
414+
claim_validators_added_by_other_recipes: List[SessionClaimValidator],
415+
user_context: Dict[str, Any],
416+
) -> Union[SessionClaimValidator, Promise[SessionClaimValidator]]:
417+
# TODO: Implement this
418+
return claim_validators_added_by_other_recipes[0]
419+
420+
async def remove_claim(
421+
self,
422+
session_handle: str,
423+
claim: SessionClaim[Any],
424+
user_context: Dict[str, Any],
425+
) -> bool:
426+
access_token_payload = claim.remove_from_payload_by_merge_({}, user_context)
427+
return await self.merge_into_access_token_payload(
428+
session_handle, access_token_payload, user_context
429+
)
430+
341431
async def regenerate_access_token(
342432
self,
343433
access_token: str,

0 commit comments

Comments
 (0)