Skip to content

Commit abba205

Browse files
committed
feat: Add primitive claim and its tests for session grants
1 parent 382d417 commit abba205

File tree

7 files changed

+590
-1
lines changed

7 files changed

+590
-1
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
2+
#
3+
# This software is licensed under the Apache License, Version 2.0 (the
4+
# "License") as published by the Apache Software Foundation.
5+
#
6+
# You may not use this file except in compliance with the License. You may
7+
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations
13+
# under the License.
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
2+
#
3+
# This software is licensed under the Apache License, Version 2.0 (the
4+
# "License") as published by the Apache Software Foundation.
5+
#
6+
# You may not use this file except in compliance with the License. You may
7+
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations
13+
# under the License.
14+
15+
from .primitive_claim import PrimitiveClaim
16+
17+
18+
class BooleanClaim(PrimitiveClaim):
19+
pass
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
2+
#
3+
# This software is licensed under the Apache License, Version 2.0 (the
4+
# "License") as published by the Apache Software Foundation.
5+
#
6+
# You may not use this file except in compliance with the License. You may
7+
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations
13+
# under the License.
14+
15+
import time
16+
from typing import Any, Dict, TypeVar, Union
17+
18+
from ..interfaces import JSONObject, JSONPrimitive, SessionClaim, SessionClaimValidator
19+
20+
_T = TypeVar("_T")
21+
22+
23+
class PrimitiveClaim(SessionClaim[JSONPrimitive]):
24+
def __init__(self, key: str, fetch_value: Any) -> None:
25+
super().__init__(key)
26+
self.fetch_value = fetch_value
27+
28+
claim = self
29+
30+
def has_value(
31+
val: JSONPrimitive, id_: Union[str, None] = None
32+
) -> SessionClaimValidator:
33+
class HasValueSCV(SessionClaimValidator):
34+
def __init__(self):
35+
self.claim = claim
36+
self.id_ = id_ or claim.key
37+
38+
def should_refetch(
39+
self,
40+
payload: JSONObject,
41+
user_context: Union[Dict[str, Any], None] = None,
42+
):
43+
return claim.get_value_from_payload(payload, user_context) is None
44+
45+
def validate(
46+
self,
47+
payload: JSONObject,
48+
user_context: Union[Dict[str, Any], None] = None,
49+
):
50+
claim_val = claim.get_value_from_payload(payload, user_context)
51+
is_valid = claim_val == val
52+
if is_valid:
53+
return {"isValid": True}
54+
55+
return {
56+
"isValid": False,
57+
"reason": {
58+
"message": "wrong value",
59+
"expectedValue": val,
60+
"actualValue": claim_val,
61+
},
62+
}
63+
64+
scv = HasValueSCV()
65+
return scv
66+
67+
def has_fresh_value(
68+
val: JSONPrimitive, max_age_in_sec: int, id_: Union[str, None] = None
69+
) -> SessionClaimValidator:
70+
class HasFreshValueSCV(SessionClaimValidator):
71+
def __init__(self):
72+
self.claim = claim
73+
self.id_ = id_ or (claim.key + "-fresh-val")
74+
75+
def should_refetch(
76+
self,
77+
payload: JSONObject,
78+
user_context: Union[Dict[str, Any], None] = None,
79+
):
80+
# (claim value is None) OR (value has expired)
81+
return (
82+
claim.get_value_from_payload(payload, user_context) is None
83+
) or (payload[claim.key]["t"] < time.time() - max_age_in_sec * 1000)
84+
85+
def validate(
86+
self,
87+
payload: JSONObject,
88+
user_context: Union[Dict[str, Any], None] = None,
89+
):
90+
claim_val = claim.get_value_from_payload(payload, user_context)
91+
if claim_val != val:
92+
return {
93+
"isValid": False,
94+
"reason": {
95+
"message": "wrong value",
96+
"expectedValue": val,
97+
"actualValue": claim_val,
98+
},
99+
}
100+
101+
age_in_sec = (time.time() - payload[claim.key]["t"]) / 1000
102+
if age_in_sec > max_age_in_sec:
103+
return {
104+
"isValid": False,
105+
"reason": {
106+
"message": "expired",
107+
"ageInSeconds": age_in_sec,
108+
"maxAgeInSeconds": max_age_in_sec,
109+
},
110+
}
111+
112+
return {"isValid": True}
113+
114+
scv = HasFreshValueSCV()
115+
return scv
116+
117+
class Validators:
118+
def __init__(self) -> None:
119+
self.has_value = has_value
120+
self.has_fresh_value = has_fresh_value
121+
122+
self.validators = Validators()
123+
124+
def add_to_payload_(
125+
self,
126+
payload: Any,
127+
value: JSONPrimitive,
128+
user_context: Union[Dict[str, Any], None] = None,
129+
) -> JSONObject:
130+
payload[self.key] = {"v": value, "t": time.time()}
131+
_ = user_context
132+
133+
return payload
134+
135+
def remove_from_payload_by_merge_(
136+
self, payload: JSONObject, user_context: Dict[str, Any]
137+
) -> JSONObject:
138+
_ = user_context
139+
140+
payload[self.key] = None
141+
return payload
142+
143+
def remove_from_payload(
144+
self, payload: JSONObject, user_context: Dict[str, Any]
145+
) -> JSONObject:
146+
_ = user_context
147+
del payload[self.key]
148+
return payload
149+
150+
def get_value_from_payload(
151+
self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None
152+
) -> Union[JSONPrimitive, None]:
153+
_ = user_context
154+
return payload.get(self.key, {}).get("v")
155+
156+
def get_last_refetch_time(
157+
self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None
158+
) -> Union[int, None]:
159+
_ = user_context
160+
return payload.get(self.key, {}).get("t")
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
2+
#
3+
# This software is licensed under the Apache License, Version 2.0 (the
4+
# "License") as published by the Apache Software Foundation.
5+
#
6+
# You may not use this file except in compliance with the License. You may
7+
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations
13+
# under the License.
14+
15+
from . import interfaces
16+
from .claim_base_classes import boolean_claim, primitive_claim
17+
18+
SessionClaim = interfaces.SessionClaim
19+
BooleanClaim = boolean_claim.BooleanClaim
20+
PrimitiveClaim = primitive_claim.PrimitiveClaim

supertokens_python/recipe/session/interfaces.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
from __future__ import annotations
1515

1616
from abc import ABC, abstractmethod
17-
from typing import TYPE_CHECKING, Any, Dict, List, Union
17+
from asyncio import iscoroutinefunction
18+
from typing import TYPE_CHECKING, Any, Dict, Generic, List, TypeVar, Union
1819

1920
from supertokens_python.async_to_sync_wrapper import sync
2021
from supertokens_python.types import APIResponse, GeneralErrorResponse
@@ -327,3 +328,88 @@ def sync_update_session_data(
327328
# This is there so that we can do session["..."] to access some of the members of this class
328329
def __getitem__(self, item: str):
329330
return getattr(self, item)
331+
332+
333+
# Session claims:
334+
_T = TypeVar("_T")
335+
JSONObject = Dict[str, Any]
336+
337+
338+
JSONPrimitive = Union[str, int, bool, None, Dict[str, Any]]
339+
340+
FetchValueReturnType = Union[_T, None]
341+
342+
343+
class SessionClaim(ABC, Generic[_T]):
344+
def __init__(self, key: str) -> None:
345+
self.key = key
346+
347+
# fetchValue(userId: string, userContext: any): Promise<T | undefined> | T | undefined;
348+
# Union[Promise[FetchValueReturnType], FetchValueReturnType]
349+
def fetch_value(
350+
self, user_id: str, user_context: Union[Dict[str, Any], None] = None
351+
) -> Any:
352+
pass
353+
354+
@abstractmethod
355+
def add_to_payload_(
356+
self,
357+
payload: JSONObject,
358+
value: _T,
359+
user_context: Union[Dict[str, Any], None] = None,
360+
) -> JSONObject:
361+
"""Saves the provided value into the payload, by cloning and updating the entire object"""
362+
363+
@abstractmethod
364+
def remove_from_payload_by_merge_(
365+
self, payload: JSONObject, user_context: Dict[str, Any]
366+
) -> JSONObject:
367+
"""Removes the claim from the payload by setting it to null, so mergeIntoAccessTokenPayload clears it"""
368+
369+
@abstractmethod
370+
def remove_from_payload(
371+
self, payload: JSONObject, user_context: Dict[str, Any]
372+
) -> JSONObject:
373+
"""Gets the value of the claim stored in the payload
374+
375+
Returns:
376+
JSONObject: Claim value
377+
"""
378+
379+
@abstractmethod
380+
def get_value_from_payload(
381+
self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None
382+
) -> Union[_T, None]:
383+
pass
384+
385+
async def build(
386+
self, user_id: str, user_context: Union[Dict[str, Any], None] = None
387+
) -> JSONObject:
388+
if iscoroutinefunction(self.fetch_value):
389+
value = await self.fetch_value(user_id, user_context)
390+
else:
391+
value = self.fetch_value( # pylint: disable=assignment-from-no-return
392+
user_id, user_context
393+
)
394+
395+
if value is None:
396+
return {}
397+
398+
return self.add_to_payload_({}, value, user_context)
399+
400+
401+
class SessionClaimValidator(ABC):
402+
id: str
403+
claim: SessionClaim[Any]
404+
405+
@abstractmethod
406+
def should_refetch(
407+
self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None
408+
) -> Any:
409+
pass
410+
411+
@abstractmethod
412+
def validate(
413+
self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None
414+
) -> Any:
415+
pass

tests/sessions/claims/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import nest_asyncio # type: ignore
2+
3+
nest_asyncio.apply() # type: ignore

0 commit comments

Comments
 (0)