Skip to content

Commit bff0cbd

Browse files
committed
feat: Session multitenancy changes
1 parent 756bf40 commit bff0cbd

File tree

13 files changed

+108
-37
lines changed

13 files changed

+108
-37
lines changed

supertokens_python/recipe/emailpassword/api/implementation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ async def sign_in_post(
182182
user.user_id,
183183
access_token_payload={},
184184
session_data_in_database={},
185+
tenant_id=tenant_id,
185186
user_context=user_context,
186187
)
187188
return SignInPostOkResult(user, session)
@@ -222,6 +223,7 @@ async def sign_up_post(
222223
user.user_id,
223224
access_token_payload={},
224225
session_data_in_database={},
226+
tenant_id=tenant_id,
225227
user_context=user_context,
226228
)
227229
return SignUpPostOkResult(user, session)

supertokens_python/recipe/emailverification/recipe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ async def generate_email_verify_token_post(
412412
email_info = await EmailVerificationRecipe.get_instance().get_email_for_user_id(
413413
user_id, user_context
414414
)
415-
tenant_id = session.get_access_token_payload()["tid"]
415+
tenant_id = session.get_tenant_id()
416416

417417
if isinstance(email_info, EmailDoesNotExistError):
418418
log_debug_message(

supertokens_python/recipe/passwordless/api/implementation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,7 @@ async def consume_code_post(
303303
user.user_id,
304304
{},
305305
{},
306+
tenant_id,
306307
user_context=user_context,
307308
)
308309

supertokens_python/recipe/session/access_token.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from .exceptions import raise_try_refresh_token_exception
2525
from .jwt import ParsedJWTInfo
2626

27+
from supertokens_python.recipe.multitenancy.constants import DEFAULT_TENANT_ID
28+
2729

2830
def sanitize_string(s: Any) -> Union[str, None]:
2931
if s == "":
@@ -102,6 +104,10 @@ def get_info_from_access_token(
102104
payload.get("parentRefreshTokenHash1")
103105
)
104106
anti_csrf_token = sanitize_string(payload.get("antiCsrfToken"))
107+
tenant_id = DEFAULT_TENANT_ID
108+
109+
if jwt_info.version >= 4:
110+
tenant_id = sanitize_string(payload.get("tId"))
105111

106112
if anti_csrf_token is None and do_anti_csrf_check:
107113
raise Exception("Access token does not contain the anti-csrf token")
@@ -120,6 +126,7 @@ def get_info_from_access_token(
120126
"antiCsrfToken": anti_csrf_token,
121127
"expiryTime": expiry_time,
122128
"timeCreated": time_created,
129+
"tenantId": tenant_id,
123130
}
124131
except Exception as e:
125132
log_debug_message(
@@ -145,6 +152,11 @@ def validate_access_token_structure(payload: Dict[str, Any], version: int) -> No
145152
raise Exception(
146153
"Access token does not contain all the information. Maybe the structure has changed?"
147154
)
155+
156+
if version >= 4:
157+
if not isinstance(payload.get("tId"), str):
158+
raise Exception("Access token does not contain all the information. Maybe the structure has changed?")
159+
148160
elif (
149161
not isinstance(payload.get("sessionHandle"), str)
150162
or payload.get("userData") is None

supertokens_python/recipe/session/asyncio/__init__.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343
)
4444
from ..utils import get_required_claim_validators
4545

46+
from supertokens_python.recipe.multitenancy.constants import DEFAULT_TENANT_ID
47+
4648
_T = TypeVar("_T")
4749

4850

@@ -51,6 +53,7 @@ async def create_new_session(
5153
user_id: str,
5254
access_token_payload: Union[Dict[str, Any], None] = None,
5355
session_data_in_database: Union[Dict[str, Any], None] = None,
56+
tenant_id: Optional[str] = None,
5457
user_context: Union[None, Dict[str, Any]] = None,
5558
) -> SessionContainer:
5659
if user_context is None:
@@ -73,6 +76,7 @@ async def create_new_session(
7376
config,
7477
app_info,
7578
session_data_in_database,
79+
tenant_id or DEFAULT_TENANT_ID,
7680
)
7781

7882

@@ -81,6 +85,7 @@ async def create_new_session_without_request_response(
8185
access_token_payload: Union[Dict[str, Any], None] = None,
8286
session_data_in_database: Union[Dict[str, Any], None] = None,
8387
disable_anti_csrf: bool = False,
88+
tenant_id: Optional[str] = None,
8489
user_context: Union[None, Dict[str, Any]] = None,
8590
) -> SessionContainer:
8691
if user_context is None:
@@ -102,7 +107,6 @@ async def create_new_session_without_request_response(
102107
final_access_token_payload = {**access_token_payload, "iss": issuer}
103108

104109
for claim in claims_added_by_other_recipes:
105-
# TODO: Pass tenant id
106110
update = await claim.build(user_id, "pass-tenant-id", user_context)
107111
final_access_token_payload = {**final_access_token_payload, **update}
108112

@@ -111,6 +115,7 @@ async def create_new_session_without_request_response(
111115
final_access_token_payload,
112116
session_data_in_database,
113117
disable_anti_csrf,
118+
tenant_id or DEFAULT_TENANT_ID,
114119
user_context=user_context,
115120
)
116121

@@ -421,22 +426,23 @@ async def revoke_session(
421426

422427

423428
async def revoke_all_sessions_for_user(
424-
user_id: str, user_context: Union[None, Dict[str, Any]] = None
429+
user_id: str, tenant_id: Optional[str], user_context: Union[None, Dict[str, Any]] = None
425430
) -> List[str]:
426431
if user_context is None:
427432
user_context = {}
428433
return await SessionRecipe.get_instance().recipe_implementation.revoke_all_sessions_for_user(
429-
user_id, user_context
434+
user_id, tenant_id or DEFAULT_TENANT_ID,
435+
tenant_id is None, user_context
430436
)
431437

432438

433439
async def get_all_session_handles_for_user(
434-
user_id: str, user_context: Union[None, Dict[str, Any]] = None
440+
user_id: str, tenant_id: Optional[str], user_context: Union[None, Dict[str, Any]] = None
435441
) -> List[str]:
436442
if user_context is None:
437443
user_context = {}
438444
return await SessionRecipe.get_instance().recipe_implementation.get_all_session_handles_for_user(
439-
user_id, user_context
445+
user_id, tenant_id or DEFAULT_TENANT_ID, tenant_id is None, user_context
440446
)
441447

442448

supertokens_python/recipe/session/interfaces.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,11 @@
4141

4242

4343
class SessionObj:
44-
def __init__(self, handle: str, user_id: str, user_data_in_jwt: Dict[str, Any]):
44+
def __init__(self, handle: str, user_id: str, user_data_in_jwt: Dict[str, Any], tenant_id: str):
4545
self.handle = handle
4646
self.user_id = user_id
4747
self.user_data_in_jwt = user_data_in_jwt
48+
self.tenant_id = tenant_id
4849

4950

5051
class AccessTokenObj:
@@ -69,15 +70,15 @@ def __init__(
6970
expiry: int,
7071
custom_claims_in_access_token_payload: Dict[str, Any],
7172
time_created: int,
73+
tenant_id: str,
7274
):
73-
self.session_handle: str = session_handle
74-
self.user_id: str = user_id
75-
self.session_data_in_database: Dict[str, Any] = session_data_in_database
76-
self.expiry: int = expiry
77-
self.custom_claims_in_access_token_payload: Dict[
78-
str, Any
79-
] = custom_claims_in_access_token_payload
80-
self.time_created: int = time_created
75+
self.session_handle = session_handle
76+
self.user_id = user_id
77+
self.session_data_in_database = session_data_in_database
78+
self.expiry = expiry
79+
self.custom_claims_in_access_token_payload = custom_claims_in_access_token_payload
80+
self.time_created = time_created
81+
self.tenant_id = tenant_id
8182

8283

8384
class ReqResInfo:
@@ -137,6 +138,7 @@ async def create_new_session(
137138
access_token_payload: Optional[Dict[str, Any]],
138139
session_data_in_database: Optional[Dict[str, Any]],
139140
disable_anti_csrf: Optional[bool],
141+
tenant_id: str,
140142
user_context: Dict[str, Any],
141143
) -> SessionContainer:
142144
pass
@@ -206,13 +208,13 @@ async def revoke_session(
206208

207209
@abstractmethod
208210
async def revoke_all_sessions_for_user(
209-
self, user_id: str, user_context: Dict[str, Any]
211+
self, user_id: str, tenant_id: str, revoke_across_all_tenants: Optional[bool], user_context: Dict[str, Any]
210212
) -> List[str]:
211213
pass
212214

213215
@abstractmethod
214216
async def get_all_session_handles_for_user(
215-
self, user_id: str, user_context: Dict[str, Any]
217+
self, user_id: str, tenant_id: str, fetch_across_all_tenants: Optional[bool], user_context: Dict[str, Any]
216218
) -> List[str]:
217219
pass
218220

@@ -383,6 +385,7 @@ def __init__(
383385
user_data_in_access_token: Optional[Dict[str, Any]],
384386
req_res_info: Optional[ReqResInfo],
385387
access_token_updated: bool,
388+
tenant_id: str,
386389
):
387390
self.recipe_implementation = recipe_implementation
388391
self.config = config
@@ -395,6 +398,7 @@ def __init__(
395398
self.user_data_in_access_token = user_data_in_access_token
396399
self.req_res_info: Optional[ReqResInfo] = req_res_info
397400
self.access_token_updated = access_token_updated
401+
self.tenant_id = tenant_id
398402

399403
self.response_mutators: List[ResponseMutator] = []
400404

@@ -436,6 +440,10 @@ async def merge_into_access_token_payload(
436440
def get_user_id(self, user_context: Optional[Dict[str, Any]] = None) -> str:
437441
pass
438442

443+
@abstractmethod
444+
def get_tenant_id(self, user_context: Optional[Dict[str, Any]] = None) -> str:
445+
pass
446+
439447
@abstractmethod
440448
def get_access_token_payload(
441449
self, user_context: Optional[Dict[str, Any]] = None

supertokens_python/recipe/session/recipe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def get_apis_handled(self) -> List[APIHandled]:
183183
async def handle_api_request(
184184
self,
185185
request_id: str,
186-
tenant_id: Optional[str],
186+
tenant_id: str,
187187
request: BaseRequest,
188188
path: NormalisedURLPath,
189189
method: str,

supertokens_python/recipe/session/recipe_implementation.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ async def create_new_session(
6464
access_token_payload: Optional[Dict[str, Any]],
6565
session_data_in_database: Optional[Dict[str, Any]],
6666
disable_anti_csrf: Optional[bool],
67+
tenant_id: str,
6768
user_context: Dict[str, Any],
6869
) -> SessionContainer:
6970
log_debug_message("createNewSession: Started")
@@ -74,6 +75,7 @@ async def create_new_session(
7475
disable_anti_csrf is True,
7576
access_token_payload,
7677
session_data_in_database,
78+
tenant_id
7779
)
7880
log_debug_message("createNewSession: Finished")
7981

@@ -95,6 +97,7 @@ async def create_new_session(
9597
payload,
9698
None,
9799
True,
100+
tenant_id
98101
)
99102

100103
return new_session
@@ -262,6 +265,7 @@ async def get_session(
262265
payload,
263266
None,
264267
access_token_updated,
268+
response.session.tenant_id,
265269
)
266270

267271
return session
@@ -312,6 +316,7 @@ async def refresh_session(
312316
user_data_in_access_token=payload,
313317
req_res_info=None,
314318
access_token_updated=True,
319+
tenant_id=payload["tId"],
315320
)
316321

317322
return session
@@ -322,14 +327,14 @@ async def revoke_session(
322327
return await session_functions.revoke_session(self, session_handle)
323328

324329
async def revoke_all_sessions_for_user(
325-
self, user_id: str, user_context: Dict[str, Any]
330+
self, user_id: str, tenant_id: Optional[str], revoke_across_all_tenants: Optional[bool], user_context: Dict[str, Any]
326331
) -> List[str]:
327-
return await session_functions.revoke_all_sessions_for_user(self, user_id)
332+
return await session_functions.revoke_all_sessions_for_user(self, user_id, tenant_id, revoke_across_all_tenants)
328333

329334
async def get_all_session_handles_for_user(
330-
self, user_id: str, user_context: Dict[str, Any]
335+
self, user_id: str, tenant_id: Optional[str], fetch_across_all_tenants: Optional[bool], user_context: Dict[str, Any]
331336
) -> List[str]:
332-
return await session_functions.get_all_session_handles_for_user(self, user_id)
337+
return await session_functions.get_all_session_handles_for_user(self, user_id, tenant_id, fetch_across_all_tenants)
333338

334339
async def revoke_multiple_sessions(
335340
self, session_handles: List[str], user_context: Dict[str, Any]
@@ -383,9 +388,8 @@ async def fetch_and_set_claim(
383388
if session_info is None:
384389
return False
385390

386-
# TODO: Pass tenant id
387391
access_token_payload_update = await claim.build(
388-
session_info.user_id, "pass-tenant-id", user_context
392+
session_info.user_id, session_info.tenant_id, user_context
389393
)
390394
return await self.merge_into_access_token_payload(
391395
session_handle, access_token_payload_update, user_context
@@ -463,5 +467,6 @@ async def regenerate_access_token(
463467
response["session"]["handle"],
464468
response["session"]["userId"],
465469
response["session"]["userDataInJWT"],
470+
response["session"]["tenantId"]
466471
)
467472
return RegenerateAccessTokenOkResult(session, access_token_obj)

supertokens_python/recipe/session/session_class.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
14-
from typing import Any, Dict, List, TypeVar, Union
14+
from typing import Any, Dict, List, Optional, TypeVar, Union
1515

1616
from supertokens_python.recipe.session.exceptions import (
1717
raise_invalid_claims_exception,
@@ -133,6 +133,9 @@ async def update_session_data_in_database(
133133
def get_user_id(self, user_context: Union[Dict[str, Any], None] = None) -> str:
134134
return self.user_id
135135

136+
def get_tenant_id(self, user_context: Union[Dict[str, Any], None] = None) -> str:
137+
return self.tenant_id
138+
136139
def get_access_token_payload(
137140
self, user_context: Union[Dict[str, Any], None] = None
138141
) -> Dict[str, Any]:

0 commit comments

Comments
 (0)