Skip to content

Commit 86b02a9

Browse files
Merge pull request #373 from supertokens/feat/multitenancy-claims
feat: Add multitenancy claim
2 parents 98e0a07 + a0b8529 commit 86b02a9

File tree

15 files changed

+66
-46
lines changed

15 files changed

+66
-46
lines changed

supertokens_python/recipe/emailverification/recipe.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,9 @@ class EmailVerificationClaimClass(BooleanClaim):
312312
def __init__(self):
313313
default_max_age_in_sec = 300
314314

315-
async def fetch_value(user_id: str, user_context: Dict[str, Any]) -> bool:
315+
async def fetch_value(
316+
user_id: str, _tenant_id: str, user_context: Dict[str, Any]
317+
) -> bool:
316318
recipe = EmailVerificationRecipe.get_instance()
317319
email_info = await recipe.get_email_for_user_id(user_id, user_context)
318320

supertokens_python/recipe/multitenancy/interfaces.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,5 @@ async def login_methods_get(
336336

337337

338338
TypeGetAllowedDomainsForTenantId = Callable[
339-
[Union[str, None], Dict[str, Any]],
340-
Awaitable[List[str]],
339+
[Optional[str], Dict[str, Any]], Awaitable[Optional[List[str]]]
341340
]

supertokens_python/recipe/multitenancy/recipe.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ def __init__(
100100
)
101101

102102
self.static_third_party_providers: List[ProviderInput] = []
103+
self.get_allowed_domains_for_tenant_id = (
104+
self.config.get_allowed_domains_for_tenant_id
105+
)
103106

104107
def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool:
105108
return isinstance(err, (TenantDoesNotExistError, RecipeDisabledForTenantError))
@@ -259,41 +262,43 @@ async def login_methods_get(
259262

260263
class AllowedDomainsClaimClass(PrimitiveArrayClaim[List[str]]):
261264
def __init__(self):
262-
async def fetch_value(_user_id: str, user_context: Dict[str, Any]) -> List[str]:
265+
default_max_age_in_sec = 60 * 60
266+
267+
async def fetch_value(
268+
_: str, tenant_id: str, user_context: Dict[str, Any]
269+
) -> Optional[List[str]]:
263270
recipe = MultitenancyRecipe.get_instance()
264-
tenant_id = (
265-
None # TODO fetch value will be passed with tenant_id as well later
266-
)
267271

268-
if recipe.config.get_allowed_domains_for_tenant_id is None:
269-
return (
270-
[]
271-
) # User did not provide a function to get allowed domains, but is using a validator. So we don't allow any domains by default
272+
if recipe.get_allowed_domains_for_tenant_id is None:
273+
# User did not provide a function to get allowed domains, but is using a validator. So we don't allow any domains by default
274+
return None
272275

273-
domains_res = await recipe.config.get_allowed_domains_for_tenant_id(
276+
return await recipe.get_allowed_domains_for_tenant_id(
274277
tenant_id, user_context
275278
)
276-
return domains_res
277279

278-
super().__init__(
279-
key="st-tenant-domains",
280-
fetch_value=fetch_value,
281-
default_max_age_in_sec=3600,
282-
)
280+
super().__init__("st-t-dmns", fetch_value, default_max_age_in_sec)
283281

284282
def get_value_from_payload(
285-
self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None
286-
) -> Union[List[str], None]:
287-
if self.key not in payload:
283+
self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None
284+
) -> Optional[List[str]]:
285+
_ = user_context
286+
287+
res = payload.get(self.key, {}).get("v")
288+
if res is None:
288289
return []
289-
return super().get_value_from_payload(payload, user_context)
290+
return res
290291

291292
def get_last_refetch_time(
292-
self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None
293-
) -> Union[int, None]:
294-
if self.key not in payload:
293+
self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None
294+
) -> Optional[int]:
295+
_ = user_context
296+
297+
res = payload.get(self.key, {}).get("t")
298+
if res is None:
295299
return get_timestamp_ms()
296-
return super().get_last_refetch_time(payload, user_context)
300+
301+
return res
297302

298303

299304
AllowedDomainsClaim = AllowedDomainsClaimClass()

supertokens_python/recipe/session/asyncio/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ async def create_new_session_without_request_response(
102102
final_access_token_payload = {**access_token_payload, "iss": issuer}
103103

104104
for claim in claims_added_by_other_recipes:
105-
update = await claim.build(user_id, user_context)
105+
# TODO: Pass tenant id
106+
update = await claim.build(user_id, "pass-tenant-id", user_context)
106107
final_access_token_payload = {**final_access_token_payload, **update}
107108

108109
return await SessionRecipe.get_instance().recipe_implementation.create_new_session(

supertokens_python/recipe/session/claim_base_classes/boolean_claim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(
3131
self,
3232
key: str,
3333
fetch_value: Callable[
34-
[str, Dict[str, Any]],
34+
[str, str, Dict[str, Any]],
3535
MaybeAwaitable[Optional[bool]],
3636
],
3737
default_max_age_in_sec: Optional[int] = None,

supertokens_python/recipe/session/claim_base_classes/primitive_array_claim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def __init__(
267267
self,
268268
key: str,
269269
fetch_value: Callable[
270-
[str, Dict[str, Any]],
270+
[str, str, Dict[str, Any]],
271271
MaybeAwaitable[Optional[PrimitiveList]],
272272
],
273273
default_max_age_in_sec: Optional[int] = None,

supertokens_python/recipe/session/claim_base_classes/primitive_claim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def __init__(
132132
self,
133133
key: str,
134134
fetch_value: Callable[
135-
[str, Dict[str, Any]],
135+
[str, str, Dict[str, Any]],
136136
MaybeAwaitable[Optional[Primitive]],
137137
],
138138
default_max_age_in_sec: Optional[int] = None,

supertokens_python/recipe/session/interfaces.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,7 @@ def __init__(
585585
self,
586586
key: str,
587587
fetch_value: Callable[
588-
[str, Dict[str, Any]],
588+
[str, str, Dict[str, Any]],
589589
MaybeAwaitable[Optional[_T]],
590590
],
591591
) -> None:
@@ -628,11 +628,12 @@ def get_value_from_payload(
628628
"""Gets the value of the claim stored in the payload"""
629629

630630
async def build(
631-
self, user_id: str, user_context: Optional[Dict[str, Any]] = None
631+
self, user_id: str, tenant_id: str, user_context: Optional[Dict[str, Any]] = None
632632
) -> JSONObject:
633633
if user_context is None:
634634
user_context = {}
635-
value = await resolve(self.fetch_value(user_id, user_context))
635+
636+
value = await resolve(self.fetch_value(user_id, tenant_id, user_context))
636637

637638
if value is None:
638639
return {}

supertokens_python/recipe/session/recipe_implementation.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848

4949
from .interfaces import SessionContainer
5050
from supertokens_python.querier import Querier
51+
from supertokens_python.recipe.multitenancy.constants import DEFAULT_TENANT_ID
5152

5253

5354
class RecipeImplementation(RecipeInterface): # pylint: disable=too-many-public-methods
@@ -120,7 +121,11 @@ async def validate_claims(
120121
"update_claims_in_payload_if_needed refetching for %s", validator.id
121122
)
122123
value = await resolve(
123-
validator.claim.fetch_value(user_id, user_context)
124+
validator.claim.fetch_value(
125+
user_id,
126+
access_token_payload.get("tId", DEFAULT_TENANT_ID),
127+
user_context,
128+
)
124129
)
125130
log_debug_message(
126131
"update_claims_in_payload_if_needed %s refetch result %s",
@@ -378,8 +383,9 @@ async def fetch_and_set_claim(
378383
if session_info is None:
379384
return False
380385

386+
# TODO: Pass tenant id
381387
access_token_payload_update = await claim.build(
382-
session_info.user_id, user_context
388+
session_info.user_id, "pass-tenant-id", user_context
383389
)
384390
return await self.merge_into_access_token_payload(
385391
session_handle, access_token_payload_update, user_context

supertokens_python/recipe/session/session_class.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,8 @@ async def fetch_and_set_claim(
220220
if user_context is None:
221221
user_context = {}
222222

223-
update = await claim.build(self.get_user_id(), user_context)
223+
# TODO: Pass tenant id
224+
update = await claim.build(self.get_user_id(), "pass-tenant-id", user_context)
224225
return await self.merge_into_access_token_payload(update, user_context)
225226

226227
async def set_claim_value(

supertokens_python/recipe/session/session_request_functions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,8 @@ async def create_new_session_in_request(
238238
final_access_token_payload = {**access_token_payload, "iss": issuer}
239239

240240
for claim in claims_added_by_other_recipes:
241-
update = await claim.build(user_id, user_context)
241+
# TODO: Pass tenant id
242+
update = await claim.build(user_id, "pass-tenant-id", user_context)
242243
final_access_token_payload = {**final_access_token_payload, **update}
243244

244245
log_debug_message("createNewSession: Access token payload built")

supertokens_python/recipe/userroles/recipe.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,9 @@ def __init__(self) -> None:
145145
key = "st-perm"
146146
default_max_age_in_sec = 300
147147

148-
async def fetch_value(user_id: str, user_context: Dict[str, Any]) -> List[str]:
148+
async def fetch_value(
149+
user_id: str, _tenant_id: str, user_context: Dict[str, Any]
150+
) -> List[str]:
149151
recipe = UserRolesRecipe.get_instance()
150152

151153
user_roles = await recipe.recipe_implementation.get_roles_for_user(
@@ -178,7 +180,9 @@ def __init__(self) -> None:
178180
key = "st-role"
179181
default_max_age_in_sec = 300
180182

181-
async def fetch_value(user_id: str, user_context: Dict[str, Any]) -> List[str]:
183+
async def fetch_value(
184+
user_id: str, _tenant_id: str, user_context: Dict[str, Any]
185+
) -> List[str]:
182186
recipe = UserRolesRecipe.get_instance()
183187
res = await recipe.recipe_implementation.get_roles_for_user(
184188
user_id, user_context

tests/sessions/claims/test_assert_claims.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ async def validate(
108108
def should_refetch(self, payload: JSONObject, user_context: Dict[str, Any]):
109109
return False
110110

111-
dummy_claim = PrimitiveClaim("st-claim", lambda _, __: "Hello world")
111+
dummy_claim = PrimitiveClaim("st-claim", lambda _, __, ___: "Hello world")
112112

113113
dummy_claim_validator = DummyClaimValidator(dummy_claim)
114114

tests/sessions/claims/test_primitive_array_claim.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,21 +57,21 @@ def patch_get_timestamp_ms(pac_time_patch: Tuple[MockerFixture, int]):
5757
async def test_primitive_claim(timestamp: int):
5858
claim = PrimitiveArrayClaim("key", sync_fetch_value)
5959
ctx = {}
60-
res = await claim.build("user_id", ctx)
60+
res = await claim.build("user_id", "public", ctx)
6161
assert res == {"key": {"t": timestamp, "v": val}}
6262

6363

6464
async def test_primitive_claim_without_async_fetch_value(timestamp: int):
6565
claim = PrimitiveArrayClaim("key", async_fetch_value)
6666
ctx = {}
67-
res = await claim.build("user_id", ctx)
67+
res = await claim.build("user_id", "public", ctx)
6868
assert res == {"key": {"t": timestamp, "v": val}}
6969

7070

7171
async def test_primitive_claim_matching__add_to_payload():
7272
claim = PrimitiveArrayClaim("key", sync_fetch_value)
7373
ctx = {}
74-
res = await claim.build("user_id", ctx)
74+
res = await claim.build("user_id", "public", ctx)
7575
assert res == claim.add_to_payload_({}, val, {})
7676

7777

tests/sessions/claims/test_primitive_claim.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,21 @@ def teardown_function(_):
2323
async def test_primitive_claim(timestamp: int):
2424
claim = PrimitiveClaim("key", sync_fetch_value)
2525
ctx = {}
26-
res = await claim.build("user_id", ctx)
26+
res = await claim.build("user_id", "public", ctx)
2727
assert res == {"key": {"t": timestamp, "v": val}}
2828

2929

3030
async def test_primitive_claim_without_async_fetch_value(timestamp: int):
3131
claim = PrimitiveClaim("key", async_fetch_value)
3232
ctx = {}
33-
res = await claim.build("user_id", ctx)
33+
res = await claim.build("user_id", "public", ctx)
3434
assert res == {"key": {"t": timestamp, "v": val}}
3535

3636

3737
async def test_primitive_claim_matching__add_to_payload():
3838
claim = PrimitiveClaim("key", sync_fetch_value)
3939
ctx = {}
40-
res = await claim.build("user_id", ctx)
40+
res = await claim.build("user_id", "public", ctx)
4141
assert res == claim.add_to_payload_({}, val, {})
4242

4343

0 commit comments

Comments
 (0)