Skip to content

feat: Session recipe multitenancy changes #381

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ async def handle_sessions_get(
if user_id is None:
raise_bad_input_exception("Missing required parameter 'userId'")

session_handles = await get_all_session_handles_for_user(user_id, user_context)
# Passing tenant id as None sets fetch_across_all_tenants to True
# which is what we want here.
session_handles = await get_all_session_handles_for_user(
user_id, None, user_context
)
sessions: List[Optional[SessionInfo]] = [None for _ in session_handles]

async def call_(i: int, session_handle: str):
Expand Down
2 changes: 2 additions & 0 deletions supertokens_python/recipe/emailpassword/api/implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ async def sign_in_post(
user.user_id,
access_token_payload={},
session_data_in_database={},
tenant_id=tenant_id,
user_context=user_context,
)
return SignInPostOkResult(user, session)
Expand Down Expand Up @@ -223,6 +224,7 @@ async def sign_up_post(
user.user_id,
access_token_payload={},
session_data_in_database={},
tenant_id=tenant_id,
user_context=user_context,
)
return SignUpPostOkResult(user, session)
2 changes: 1 addition & 1 deletion supertokens_python/recipe/emailverification/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ async def generate_email_verify_token_post(
email_info = await EmailVerificationRecipe.get_instance().get_email_for_user_id(
user_id, user_context
)
tenant_id = session.get_access_token_payload()["tId"]
tenant_id = session.get_tenant_id()

if isinstance(email_info, EmailDoesNotExistError):
log_debug_message(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ async def consume_code_post(
user.user_id,
{},
{},
tenant_id,
user_context=user_context,
)

Expand Down
14 changes: 14 additions & 0 deletions supertokens_python/recipe/session/access_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from .exceptions import raise_try_refresh_token_exception
from .jwt import ParsedJWTInfo

from supertokens_python.recipe.multitenancy.constants import DEFAULT_TENANT_ID


def sanitize_string(s: Any) -> Union[str, None]:
if s == "":
Expand Down Expand Up @@ -102,6 +104,10 @@ def get_info_from_access_token(
payload.get("parentRefreshTokenHash1")
)
anti_csrf_token = sanitize_string(payload.get("antiCsrfToken"))
tenant_id = DEFAULT_TENANT_ID

if jwt_info.version >= 4:
tenant_id = sanitize_string(payload.get("tId"))

if anti_csrf_token is None and do_anti_csrf_check:
raise Exception("Access token does not contain the anti-csrf token")
Expand All @@ -120,6 +126,7 @@ def get_info_from_access_token(
"antiCsrfToken": anti_csrf_token,
"expiryTime": expiry_time,
"timeCreated": time_created,
"tenantId": tenant_id,
}
except Exception as e:
log_debug_message(
Expand All @@ -145,6 +152,13 @@ def validate_access_token_structure(payload: Dict[str, Any], version: int) -> No
raise Exception(
"Access token does not contain all the information. Maybe the structure has changed?"
)

if version >= 4:
if not isinstance(payload.get("tId"), str):
raise Exception(
"Access token does not contain all the information. Maybe the structure has changed?"
)

elif (
not isinstance(payload.get("sessionHandle"), str)
or payload.get("userData") is None
Expand Down
23 changes: 17 additions & 6 deletions supertokens_python/recipe/session/asyncio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
)
from ..utils import get_required_claim_validators

from supertokens_python.recipe.multitenancy.constants import DEFAULT_TENANT_ID

_T = TypeVar("_T")


Expand All @@ -51,6 +53,7 @@ async def create_new_session(
user_id: str,
access_token_payload: Union[Dict[str, Any], None] = None,
session_data_in_database: Union[Dict[str, Any], None] = None,
tenant_id: Optional[str] = None,
user_context: Union[None, Dict[str, Any]] = None,
) -> SessionContainer:
if user_context is None:
Expand All @@ -73,6 +76,7 @@ async def create_new_session(
config,
app_info,
session_data_in_database,
tenant_id or DEFAULT_TENANT_ID,
)


Expand All @@ -81,6 +85,7 @@ async def create_new_session_without_request_response(
access_token_payload: Union[Dict[str, Any], None] = None,
session_data_in_database: Union[Dict[str, Any], None] = None,
disable_anti_csrf: bool = False,
tenant_id: Optional[str] = None,
user_context: Union[None, Dict[str, Any]] = None,
) -> SessionContainer:
if user_context is None:
Expand All @@ -102,15 +107,17 @@ async def create_new_session_without_request_response(
final_access_token_payload = {**access_token_payload, "iss": issuer}

for claim in claims_added_by_other_recipes:
# TODO: Pass tenant id
update = await claim.build(user_id, "pass-tenant-id", user_context)
update = await claim.build(
user_id, tenant_id or DEFAULT_TENANT_ID, user_context
)
final_access_token_payload = {**final_access_token_payload, **update}

return await SessionRecipe.get_instance().recipe_implementation.create_new_session(
user_id,
final_access_token_payload,
session_data_in_database,
disable_anti_csrf,
tenant_id or DEFAULT_TENANT_ID,
user_context=user_context,
)

Expand Down Expand Up @@ -421,22 +428,26 @@ async def revoke_session(


async def revoke_all_sessions_for_user(
user_id: str, user_context: Union[None, Dict[str, Any]] = None
user_id: str,
tenant_id: Optional[str] = None,
user_context: Union[None, Dict[str, Any]] = None,
) -> List[str]:
if user_context is None:
user_context = {}
return await SessionRecipe.get_instance().recipe_implementation.revoke_all_sessions_for_user(
user_id, user_context
user_id, tenant_id or DEFAULT_TENANT_ID, tenant_id is None, user_context
)


async def get_all_session_handles_for_user(
user_id: str, user_context: Union[None, Dict[str, Any]] = None
user_id: str,
tenant_id: Optional[str] = None,
user_context: Union[None, Dict[str, Any]] = None,
) -> List[str]:
if user_context is None:
user_context = {}
return await SessionRecipe.get_instance().recipe_implementation.get_all_session_handles_for_user(
user_id, user_context
user_id, tenant_id or DEFAULT_TENANT_ID, tenant_id is None, user_context
)


Expand Down
46 changes: 35 additions & 11 deletions supertokens_python/recipe/session/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,17 @@


class SessionObj:
def __init__(self, handle: str, user_id: str, user_data_in_jwt: Dict[str, Any]):
def __init__(
self,
handle: str,
user_id: str,
user_data_in_jwt: Dict[str, Any],
tenant_id: str,
):
self.handle = handle
self.user_id = user_id
self.user_data_in_jwt = user_data_in_jwt
self.tenant_id = tenant_id


class AccessTokenObj:
Expand All @@ -69,15 +76,17 @@ def __init__(
expiry: int,
custom_claims_in_access_token_payload: Dict[str, Any],
time_created: int,
tenant_id: str,
):
self.session_handle: str = session_handle
self.user_id: str = user_id
self.session_data_in_database: Dict[str, Any] = session_data_in_database
self.expiry: int = expiry
self.custom_claims_in_access_token_payload: Dict[
str, Any
] = custom_claims_in_access_token_payload
self.time_created: int = time_created
self.session_handle = session_handle
self.user_id = user_id
self.session_data_in_database = session_data_in_database
self.expiry = expiry
self.custom_claims_in_access_token_payload = (
custom_claims_in_access_token_payload
)
self.time_created = time_created
self.tenant_id = tenant_id


class ReqResInfo:
Expand Down Expand Up @@ -137,6 +146,7 @@ async def create_new_session(
access_token_payload: Optional[Dict[str, Any]],
session_data_in_database: Optional[Dict[str, Any]],
disable_anti_csrf: Optional[bool],
tenant_id: str,
user_context: Dict[str, Any],
) -> SessionContainer:
pass
Expand Down Expand Up @@ -206,13 +216,21 @@ async def revoke_session(

@abstractmethod
async def revoke_all_sessions_for_user(
self, user_id: str, user_context: Dict[str, Any]
self,
user_id: str,
tenant_id: str,
revoke_across_all_tenants: bool,
user_context: Dict[str, Any],
) -> List[str]:
pass

@abstractmethod
async def get_all_session_handles_for_user(
self, user_id: str, user_context: Dict[str, Any]
self,
user_id: str,
tenant_id: str,
fetch_across_all_tenants: bool,
user_context: Dict[str, Any],
) -> List[str]:
pass

Expand Down Expand Up @@ -383,6 +401,7 @@ def __init__(
user_data_in_access_token: Optional[Dict[str, Any]],
req_res_info: Optional[ReqResInfo],
access_token_updated: bool,
tenant_id: str,
):
self.recipe_implementation = recipe_implementation
self.config = config
Expand All @@ -395,6 +414,7 @@ def __init__(
self.user_data_in_access_token = user_data_in_access_token
self.req_res_info: Optional[ReqResInfo] = req_res_info
self.access_token_updated = access_token_updated
self.tenant_id = tenant_id

self.response_mutators: List[ResponseMutator] = []

Expand Down Expand Up @@ -436,6 +456,10 @@ async def merge_into_access_token_payload(
def get_user_id(self, user_context: Optional[Dict[str, Any]] = None) -> str:
pass

@abstractmethod
def get_tenant_id(self, user_context: Optional[Dict[str, Any]] = None) -> str:
pass

@abstractmethod
def get_access_token_payload(
self, user_context: Optional[Dict[str, Any]] = None
Expand Down
2 changes: 1 addition & 1 deletion supertokens_python/recipe/session/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def get_apis_handled(self) -> List[APIHandled]:
async def handle_api_request(
self,
request_id: str,
tenant_id: Optional[str],
tenant_id: str,
request: BaseRequest,
path: NormalisedURLPath,
method: str,
Expand Down
29 changes: 23 additions & 6 deletions supertokens_python/recipe/session/recipe_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ async def create_new_session(
access_token_payload: Optional[Dict[str, Any]],
session_data_in_database: Optional[Dict[str, Any]],
disable_anti_csrf: Optional[bool],
tenant_id: str,
user_context: Dict[str, Any],
) -> SessionContainer:
log_debug_message("createNewSession: Started")
Expand All @@ -74,6 +75,7 @@ async def create_new_session(
disable_anti_csrf is True,
access_token_payload,
session_data_in_database,
tenant_id,
)
log_debug_message("createNewSession: Finished")

Expand All @@ -95,6 +97,7 @@ async def create_new_session(
payload,
None,
True,
tenant_id,
)

return new_session
Expand Down Expand Up @@ -262,6 +265,7 @@ async def get_session(
payload,
None,
access_token_updated,
response.session.tenant_id,
)

return session
Expand Down Expand Up @@ -312,6 +316,7 @@ async def refresh_session(
user_data_in_access_token=payload,
req_res_info=None,
access_token_updated=True,
tenant_id=payload["tId"],
)

return session
Expand All @@ -322,14 +327,26 @@ async def revoke_session(
return await session_functions.revoke_session(self, session_handle)

async def revoke_all_sessions_for_user(
self, user_id: str, user_context: Dict[str, Any]
self,
user_id: str,
tenant_id: Optional[str],
revoke_across_all_tenants: bool,
user_context: Dict[str, Any],
) -> List[str]:
return await session_functions.revoke_all_sessions_for_user(self, user_id)
return await session_functions.revoke_all_sessions_for_user(
self, user_id, tenant_id, revoke_across_all_tenants
)

async def get_all_session_handles_for_user(
self, user_id: str, user_context: Dict[str, Any]
self,
user_id: str,
tenant_id: Optional[str],
fetch_across_all_tenants: bool,
user_context: Dict[str, Any],
) -> List[str]:
return await session_functions.get_all_session_handles_for_user(self, user_id)
return await session_functions.get_all_session_handles_for_user(
self, user_id, tenant_id, fetch_across_all_tenants
)

async def revoke_multiple_sessions(
self, session_handles: List[str], user_context: Dict[str, Any]
Expand Down Expand Up @@ -383,9 +400,8 @@ async def fetch_and_set_claim(
if session_info is None:
return False

# TODO: Pass tenant id
access_token_payload_update = await claim.build(
session_info.user_id, "pass-tenant-id", user_context
session_info.user_id, session_info.tenant_id, user_context
)
return await self.merge_into_access_token_payload(
session_handle, access_token_payload_update, user_context
Expand Down Expand Up @@ -463,5 +479,6 @@ async def regenerate_access_token(
response["session"]["handle"],
response["session"]["userId"],
response["session"]["userDataInJWT"],
response["session"]["tenantId"],
)
return RegenerateAccessTokenOkResult(session, access_token_obj)
3 changes: 3 additions & 0 deletions supertokens_python/recipe/session/session_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ async def update_session_data_in_database(
def get_user_id(self, user_context: Union[Dict[str, Any], None] = None) -> str:
return self.user_id

def get_tenant_id(self, user_context: Union[Dict[str, Any], None] = None) -> str:
return self.tenant_id

def get_access_token_payload(
self, user_context: Union[Dict[str, Any], None] = None
) -> Dict[str, Any]:
Expand Down
Loading