Skip to content

feat: Add multitenancy claim #373

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion supertokens_python/recipe/emailverification/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,9 @@ class EmailVerificationClaimClass(BooleanClaim):
def __init__(self):
default_max_age_in_sec = 300

async def fetch_value(user_id: str, user_context: Dict[str, Any]) -> bool:
async def fetch_value(
user_id: str, _tenant_id: str, user_context: Dict[str, Any]
) -> bool:
recipe = EmailVerificationRecipe.get_instance()
email_info = await recipe.get_email_for_user_id(user_id, user_context)

Expand Down
3 changes: 1 addition & 2 deletions supertokens_python/recipe/multitenancy/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,5 @@ async def login_methods_get(


TypeGetAllowedDomainsForTenantId = Callable[
[Union[str, None], Dict[str, Any]],
Awaitable[List[str]],
[Optional[str], Dict[str, Any]], Awaitable[Optional[List[str]]]
]
51 changes: 28 additions & 23 deletions supertokens_python/recipe/multitenancy/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ def __init__(
)

self.static_third_party_providers: List[ProviderInput] = []
self.get_allowed_domains_for_tenant_id = (
self.config.get_allowed_domains_for_tenant_id
)

def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool:
return isinstance(err, (TenantDoesNotExistError, RecipeDisabledForTenantError))
Expand Down Expand Up @@ -259,41 +262,43 @@ async def login_methods_get(

class AllowedDomainsClaimClass(PrimitiveArrayClaim[List[str]]):
def __init__(self):
async def fetch_value(_user_id: str, user_context: Dict[str, Any]) -> List[str]:
default_max_age_in_sec = 60 * 60

async def fetch_value(
_: str, tenant_id: str, user_context: Dict[str, Any]
) -> Optional[List[str]]:
recipe = MultitenancyRecipe.get_instance()
tenant_id = (
None # TODO fetch value will be passed with tenant_id as well later
)

if recipe.config.get_allowed_domains_for_tenant_id is None:
return (
[]
) # User did not provide a function to get allowed domains, but is using a validator. So we don't allow any domains by default
if recipe.get_allowed_domains_for_tenant_id is None:
# User did not provide a function to get allowed domains, but is using a validator. So we don't allow any domains by default
return None

domains_res = await recipe.config.get_allowed_domains_for_tenant_id(
return await recipe.get_allowed_domains_for_tenant_id(
tenant_id, user_context
)
return domains_res

super().__init__(
key="st-tenant-domains",
fetch_value=fetch_value,
default_max_age_in_sec=3600,
)
super().__init__("st-t-dmns", fetch_value, default_max_age_in_sec)

def get_value_from_payload(
self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None
) -> Union[List[str], None]:
if self.key not in payload:
self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None
) -> Optional[List[str]]:
_ = user_context

res = payload.get(self.key, {}).get("v")
if res is None:
return []
return super().get_value_from_payload(payload, user_context)
return res

def get_last_refetch_time(
self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None
) -> Union[int, None]:
if self.key not in payload:
self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None
) -> Optional[int]:
_ = user_context

res = payload.get(self.key, {}).get("t")
if res is None:
return get_timestamp_ms()
return super().get_last_refetch_time(payload, user_context)

return res


AllowedDomainsClaim = AllowedDomainsClaimClass()
3 changes: 2 additions & 1 deletion supertokens_python/recipe/session/asyncio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ async def create_new_session_without_request_response(
final_access_token_payload = {**access_token_payload, "iss": issuer}

for claim in claims_added_by_other_recipes:
update = await claim.build(user_id, user_context)
# TODO: Pass tenant id
update = await claim.build(user_id, "pass-tenant-id", user_context)
final_access_token_payload = {**final_access_token_payload, **update}

return await SessionRecipe.get_instance().recipe_implementation.create_new_session(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(
self,
key: str,
fetch_value: Callable[
[str, Dict[str, Any]],
[str, str, Dict[str, Any]],
MaybeAwaitable[Optional[bool]],
],
default_max_age_in_sec: Optional[int] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def __init__(
self,
key: str,
fetch_value: Callable[
[str, Dict[str, Any]],
[str, str, Dict[str, Any]],
MaybeAwaitable[Optional[PrimitiveList]],
],
default_max_age_in_sec: Optional[int] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def __init__(
self,
key: str,
fetch_value: Callable[
[str, Dict[str, Any]],
[str, str, Dict[str, Any]],
MaybeAwaitable[Optional[Primitive]],
],
default_max_age_in_sec: Optional[int] = None,
Expand Down
7 changes: 4 additions & 3 deletions supertokens_python/recipe/session/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ def __init__(
self,
key: str,
fetch_value: Callable[
[str, Dict[str, Any]],
[str, str, Dict[str, Any]],
MaybeAwaitable[Optional[_T]],
],
) -> None:
Expand Down Expand Up @@ -628,11 +628,12 @@ def get_value_from_payload(
"""Gets the value of the claim stored in the payload"""

async def build(
self, user_id: str, user_context: Optional[Dict[str, Any]] = None
self, user_id: str, tenant_id: str, user_context: Optional[Dict[str, Any]] = None
) -> JSONObject:
if user_context is None:
user_context = {}
value = await resolve(self.fetch_value(user_id, user_context))

value = await resolve(self.fetch_value(user_id, tenant_id, user_context))

if value is None:
return {}
Expand Down
10 changes: 8 additions & 2 deletions supertokens_python/recipe/session/recipe_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@

from .interfaces import SessionContainer
from supertokens_python.querier import Querier
from supertokens_python.recipe.multitenancy.constants import DEFAULT_TENANT_ID


class RecipeImplementation(RecipeInterface): # pylint: disable=too-many-public-methods
Expand Down Expand Up @@ -120,7 +121,11 @@ async def validate_claims(
"update_claims_in_payload_if_needed refetching for %s", validator.id
)
value = await resolve(
validator.claim.fetch_value(user_id, user_context)
validator.claim.fetch_value(
user_id,
access_token_payload.get("tId", DEFAULT_TENANT_ID),
user_context,
)
)
log_debug_message(
"update_claims_in_payload_if_needed %s refetch result %s",
Expand Down Expand Up @@ -378,8 +383,9 @@ async def fetch_and_set_claim(
if session_info is None:
return False

# TODO: Pass tenant id
access_token_payload_update = await claim.build(
session_info.user_id, user_context
session_info.user_id, "pass-tenant-id", user_context
)
return await self.merge_into_access_token_payload(
session_handle, access_token_payload_update, user_context
Expand Down
3 changes: 2 additions & 1 deletion supertokens_python/recipe/session/session_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ async def fetch_and_set_claim(
if user_context is None:
user_context = {}

update = await claim.build(self.get_user_id(), user_context)
# TODO: Pass tenant id
update = await claim.build(self.get_user_id(), "pass-tenant-id", user_context)
return await self.merge_into_access_token_payload(update, user_context)

async def set_claim_value(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ async def create_new_session_in_request(
final_access_token_payload = {**access_token_payload, "iss": issuer}

for claim in claims_added_by_other_recipes:
update = await claim.build(user_id, user_context)
# TODO: Pass tenant id
update = await claim.build(user_id, "pass-tenant-id", user_context)
final_access_token_payload = {**final_access_token_payload, **update}

log_debug_message("createNewSession: Access token payload built")
Expand Down
8 changes: 6 additions & 2 deletions supertokens_python/recipe/userroles/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ def __init__(self) -> None:
key = "st-perm"
default_max_age_in_sec = 300

async def fetch_value(user_id: str, user_context: Dict[str, Any]) -> List[str]:
async def fetch_value(
user_id: str, _tenant_id: str, user_context: Dict[str, Any]
) -> List[str]:
recipe = UserRolesRecipe.get_instance()

user_roles = await recipe.recipe_implementation.get_roles_for_user(
Expand Down Expand Up @@ -178,7 +180,9 @@ def __init__(self) -> None:
key = "st-role"
default_max_age_in_sec = 300

async def fetch_value(user_id: str, user_context: Dict[str, Any]) -> List[str]:
async def fetch_value(
user_id: str, _tenant_id: str, user_context: Dict[str, Any]
) -> List[str]:
recipe = UserRolesRecipe.get_instance()
res = await recipe.recipe_implementation.get_roles_for_user(
user_id, user_context
Expand Down
2 changes: 1 addition & 1 deletion tests/sessions/claims/test_assert_claims.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ async def validate(
def should_refetch(self, payload: JSONObject, user_context: Dict[str, Any]):
return False

dummy_claim = PrimitiveClaim("st-claim", lambda _, __: "Hello world")
dummy_claim = PrimitiveClaim("st-claim", lambda _, __, ___: "Hello world")

dummy_claim_validator = DummyClaimValidator(dummy_claim)

Expand Down
6 changes: 3 additions & 3 deletions tests/sessions/claims/test_primitive_array_claim.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,21 @@ def patch_get_timestamp_ms(pac_time_patch: Tuple[MockerFixture, int]):
async def test_primitive_claim(timestamp: int):
claim = PrimitiveArrayClaim("key", sync_fetch_value)
ctx = {}
res = await claim.build("user_id", ctx)
res = await claim.build("user_id", "public", ctx)
assert res == {"key": {"t": timestamp, "v": val}}


async def test_primitive_claim_without_async_fetch_value(timestamp: int):
claim = PrimitiveArrayClaim("key", async_fetch_value)
ctx = {}
res = await claim.build("user_id", ctx)
res = await claim.build("user_id", "public", ctx)
assert res == {"key": {"t": timestamp, "v": val}}


async def test_primitive_claim_matching__add_to_payload():
claim = PrimitiveArrayClaim("key", sync_fetch_value)
ctx = {}
res = await claim.build("user_id", ctx)
res = await claim.build("user_id", "public", ctx)
assert res == claim.add_to_payload_({}, val, {})


Expand Down
6 changes: 3 additions & 3 deletions tests/sessions/claims/test_primitive_claim.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,21 @@ def teardown_function(_):
async def test_primitive_claim(timestamp: int):
claim = PrimitiveClaim("key", sync_fetch_value)
ctx = {}
res = await claim.build("user_id", ctx)
res = await claim.build("user_id", "public", ctx)
assert res == {"key": {"t": timestamp, "v": val}}


async def test_primitive_claim_without_async_fetch_value(timestamp: int):
claim = PrimitiveClaim("key", async_fetch_value)
ctx = {}
res = await claim.build("user_id", ctx)
res = await claim.build("user_id", "public", ctx)
assert res == {"key": {"t": timestamp, "v": val}}


async def test_primitive_claim_matching__add_to_payload():
claim = PrimitiveClaim("key", sync_fetch_value)
ctx = {}
res = await claim.build("user_id", ctx)
res = await claim.build("user_id", "public", ctx)
assert res == claim.add_to_payload_({}, val, {})


Expand Down