Skip to content

Commit d7c0183

Browse files
committed
feat: Improve types for session related functions
1 parent 9e013e9 commit d7c0183

File tree

7 files changed

+221
-286
lines changed

7 files changed

+221
-286
lines changed

supertokens_python/recipe/session/asyncio/__init__.py

Lines changed: 25 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# under the License.
1414
from typing import Any, Dict, List, Union, TypeVar, Callable, Optional
1515

16-
from supertokens_python.exceptions import SuperTokensError
1716
from supertokens_python.recipe.openid.interfaces import (
1817
GetOpenIdDiscoveryConfigurationResult,
1918
)
@@ -27,15 +26,6 @@
2726
ClaimsValidationResult,
2827
JSONObject,
2928
GetClaimValueOkResult,
30-
GetSessionUnauthorizedErrorResult,
31-
GetSessionTryRefreshTokenErrorResult,
32-
GetSessionClaimValidationErrorResult,
33-
GetSessionClaimValidationErrorResponseObject,
34-
CreateNewSessionResult,
35-
GetSessionOkResult,
36-
RefreshSessionOkResult,
37-
RefreshSessionUnauthorizedResult,
38-
RefreshSessionTokenTheftErrorResult,
3929
)
4030
from supertokens_python.recipe.session.recipe import (
4131
SessionRecipe,
@@ -47,7 +37,6 @@
4737
)
4838
from supertokens_python.types import MaybeAwaitable
4939
from supertokens_python.utils import FRAMEWORKS, resolve
50-
from ..exceptions import InvalidClaimsError
5140
from ..utils import get_required_claim_validators
5241
from ...jwt.interfaces import (
5342
CreateJwtOkResult,
@@ -94,7 +83,7 @@ async def create_new_session_without_request_response(
9483
session_data_in_database: Union[Dict[str, Any], None] = None,
9584
disable_anti_csrf: bool = False,
9685
user_context: Union[None, Dict[str, Any]] = None,
97-
) -> CreateNewSessionResult:
86+
) -> SessionContainer:
9887
if user_context is None:
9988
user_context = {}
10089
if session_data_in_database is None:
@@ -294,6 +283,9 @@ async def get_session(
294283
if user_context is None:
295284
user_context = {}
296285

286+
if session_required is None:
287+
session_required = True
288+
297289
recipe_instance = SessionRecipe.get_instance()
298290
recipe_interface_impl = recipe_instance.recipe_implementation
299291
config = recipe_instance.config
@@ -314,6 +306,7 @@ async def get_session_without_request_response(
314306
access_token: str,
315307
anti_csrf_token: Optional[str] = None,
316308
anti_csrf_check: Optional[bool] = None,
309+
session_required: Optional[bool] = None,
317310
check_database: Optional[bool] = None,
318311
override_global_claim_validators: Optional[
319312
Callable[
@@ -322,12 +315,7 @@ async def get_session_without_request_response(
322315
]
323316
] = None,
324317
user_context: Union[None, Dict[str, Any]] = None,
325-
) -> Union[
326-
GetSessionOkResult,
327-
GetSessionUnauthorizedErrorResult,
328-
GetSessionTryRefreshTokenErrorResult,
329-
GetSessionClaimValidationErrorResult,
330-
]:
318+
) -> Optional[SessionContainer]:
331319
"""Tries to validate an access token and build a Session object from it.
332320
333321
Notes about anti-csrf checking:
@@ -338,49 +326,44 @@ async def get_session_without_request_response(
338326
Args:
339327
- access_token: The access token extracted from the authorization header or cookies
340328
- anti_csrf_token: The anti-csrf token extracted from the authorization header or cookies. Can be undefined if antiCsrfCheck is false
341-
- anti_csrf_check: If true, anti-csrf checking will be done. If false, it will be skipped. Defaults behaviour to check.
342-
- check_database: If true, the session will be checked in the database. If false, it will be skipped. Defaults behaviour to skip.
329+
- anti_csrf_check: If true, anti-csrf checking will be done. If false, it will be skipped. Default behaviour is to check.
330+
- session_required: If true, throws an error if the session does not exist. Default is True.
331+
- check_database: If true, the session will be checked in the database. If false, it will be skipped. Default behaviour is to skip.
343332
- override_global_claim_validators: Alter the
344333
- user_context: user context
345334
346-
Returned values:
347-
- GetSessionOkResult: The session was successfully validated, including claim validation
348-
- GetSessionClaimValidationErrorResult: While the access token is valid, one or more claim validators have failed. Our frontend SDKs expect a 403 response the contents matching the value returned from this function.
349-
- GetSessionTryRefreshTokenErrorResult: This means, that the access token structure was valid, but it didn't pass validation for some reason and the user should call the refresh API.
350-
- You can send a 401 response to trigger this behaviour if you are using our frontend SDKs
351-
- GetSessionUnauthorizedErrorResult: This means that the access token likely doesn't belong to a SuperTokens session. If this is unexpected, it's best handled by sending a 401 response.
335+
Returned statuses:
336+
- OK: The session was successfully validated, including claim validation
337+
- CLAIM_VALIDATION_ERROR: While the access token is valid, one or more claim validators have failed. Our frontend SDKs expect a 403 response the contents matching the value returned from this function.
338+
- TRY_REFRESH_TOKEN_ERROR: This means, that the access token structure was valid, but it didn't pass validation for some reason and the user should call the refresh API.
339+
You can send a 401 response to trigger this behaviour if you are using our frontend SDKs
340+
- UNAUTHORISED: This means that the access token likely doesn't belong to a SuperTokens session. If this is unexpected, it's best handled by sending a 401 response.
352341
"""
353342
if user_context is None:
354343
user_context = {}
355344

345+
if session_required is None:
346+
session_required = True
347+
356348
recipe_interface_impl = SessionRecipe.get_instance().recipe_implementation
357349

358-
res = await recipe_interface_impl.get_session(
350+
session_ = await recipe_interface_impl.get_session(
359351
access_token,
360352
anti_csrf_token,
361353
anti_csrf_check,
354+
session_required,
362355
check_database,
363356
override_global_claim_validators,
364357
user_context,
365358
)
366359

367-
if isinstance(res, GetSessionOkResult):
360+
if session_ is not None:
368361
claim_validators = await get_required_claim_validators(
369-
res.session, override_global_claim_validators, user_context
362+
session_, override_global_claim_validators, user_context
370363
)
371-
try:
372-
await res.session.assert_claims(claim_validators, user_context)
373-
except SuperTokensError as e:
374-
if isinstance(e, InvalidClaimsError):
375-
return GetSessionClaimValidationErrorResult(
376-
error=e,
377-
response=GetSessionClaimValidationErrorResponseObject(
378-
message="invalid claim", claim_validation_errors=e.payload
379-
),
380-
)
381-
raise e
364+
await session_.assert_claims(claim_validators, user_context)
382365

383-
return res
366+
return session_
384367

385368

386369
async def refresh_session(
@@ -412,11 +395,7 @@ async def refresh_session_without_request_response(
412395
disable_anti_csrf: bool = False,
413396
anti_csrf_token: Optional[str] = None,
414397
user_context: Optional[Dict[str, Any]] = None,
415-
) -> Union[
416-
RefreshSessionOkResult,
417-
RefreshSessionUnauthorizedResult,
418-
RefreshSessionTokenTheftErrorResult,
419-
]:
398+
) -> SessionContainer:
420399
if user_context is None:
421400
user_context = {}
422401

supertokens_python/recipe/session/interfaces.py

Lines changed: 18 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,13 @@
2525
TypeVar,
2626
Union,
2727
)
28+
from typing_extensions import TypedDict
2829

2930
from supertokens_python.async_to_sync_wrapper import sync
3031
from supertokens_python.types import APIResponse, GeneralErrorResponse, MaybeAwaitable
3132

3233
from ...utils import resolve
33-
from .exceptions import (
34-
ClaimValidationError,
35-
UnauthorisedError,
36-
TokenTheftError,
37-
TryRefreshTokenError,
38-
InvalidClaimsError,
39-
)
34+
from .exceptions import ClaimValidationError
4035
from .utils import SessionConfig, TokenTransferMethod
4136

4237
if TYPE_CHECKING:
@@ -93,75 +88,6 @@ def __init__(
9388
self.transfer_method = transfer_method
9489

9590

96-
class CreateNewSessionResult:
97-
status = "OK"
98-
99-
def __init__(self, session: SessionContainer):
100-
self.session = session
101-
102-
103-
class GetSessionOkResult:
104-
status = "OK"
105-
106-
def __init__(self, session: SessionContainer):
107-
self.session = session
108-
109-
110-
class GetSessionUnauthorizedErrorResult:
111-
status = "UNAUTHORISED"
112-
113-
def __init__(self, error: Exception):
114-
self.error = error
115-
116-
117-
class GetSessionTryRefreshTokenErrorResult:
118-
status = "TRY_REFRESH_TOKEN_ERROR"
119-
120-
def __init__(self, error: TryRefreshTokenError):
121-
self.error = error
122-
123-
124-
class GetSessionClaimValidationErrorResponseObject:
125-
def __init__(
126-
self, message: str, claim_validation_errors: List[ClaimValidationError]
127-
):
128-
self.message = message
129-
self.claim_validation_errors = claim_validation_errors
130-
131-
132-
class GetSessionClaimValidationErrorResult:
133-
status = "CLAIM_VALIDATION_ERROR"
134-
135-
def __init__(
136-
self,
137-
error: InvalidClaimsError,
138-
response: GetSessionClaimValidationErrorResponseObject,
139-
):
140-
self.error = error
141-
self.response = response
142-
143-
144-
class RefreshSessionOkResult:
145-
status = "OK"
146-
147-
def __init__(self, session: SessionContainer):
148-
self.session = session
149-
150-
151-
class RefreshSessionUnauthorizedResult:
152-
status = "UNAUTHORISED"
153-
154-
def __init__(self, error: UnauthorisedError):
155-
self.error = error
156-
157-
158-
class RefreshSessionTokenTheftErrorResult:
159-
status = "TOKEN_THEFT_ERROR"
160-
161-
def __init__(self, error: TokenTheftError):
162-
self.error = error
163-
164-
16591
_T = TypeVar("_T")
16692
JSONObject = Dict[str, Any]
16793

@@ -190,6 +116,14 @@ def __init__(
190116
self.access_token_payload_update = access_token_payload_update
191117

192118

119+
class GetSessionTokensDangerouslyDict(TypedDict):
120+
accessToken: str
121+
accessAndFrontTokenUpdated: bool
122+
refreshToken: Optional[str]
123+
frontToken: str
124+
antiCsrfToken: Optional[str]
125+
126+
193127
class RecipeInterface(ABC): # pylint: disable=too-many-public-methods
194128
def __init__(self):
195129
pass
@@ -202,7 +136,7 @@ async def create_new_session(
202136
session_data_in_database: Optional[Dict[str, Any]],
203137
disable_anti_csrf: Optional[bool],
204138
user_context: Dict[str, Any],
205-
) -> CreateNewSessionResult:
139+
) -> SessionContainer:
206140
pass
207141

208142
@abstractmethod
@@ -220,6 +154,7 @@ async def get_session(
220154
access_token: str,
221155
anti_csrf_token: Optional[str],
222156
anti_csrf_check: Optional[bool] = None,
157+
session_required: Optional[bool] = None,
223158
check_database: Optional[bool] = None,
224159
override_global_claim_validators: Optional[
225160
Callable[
@@ -228,11 +163,7 @@ async def get_session(
228163
]
229164
] = None,
230165
user_context: Optional[Dict[str, Any]] = None,
231-
) -> Union[
232-
GetSessionOkResult,
233-
GetSessionUnauthorizedErrorResult,
234-
GetSessionTryRefreshTokenErrorResult,
235-
]:
166+
) -> Optional[SessionContainer]:
236167
pass
237168

238169
@abstractmethod
@@ -262,11 +193,7 @@ async def refresh_session(
262193
anti_csrf_token: Optional[str],
263194
disable_anti_csrf: bool,
264195
user_context: Dict[str, Any],
265-
) -> Union[
266-
RefreshSessionOkResult,
267-
RefreshSessionUnauthorizedResult,
268-
RefreshSessionTokenTheftErrorResult,
269-
]:
196+
) -> SessionContainer:
270197
pass
271198

272199
@abstractmethod
@@ -516,6 +443,10 @@ def get_access_token_payload(
516443
def get_handle(self, user_context: Optional[Dict[str, Any]] = None) -> str:
517444
pass
518445

446+
@abstractmethod
447+
def get_all_session_tokens_dangerously(self) -> GetSessionTokensDangerouslyDict:
448+
pass
449+
519450
@abstractmethod
520451
def get_access_token(self, user_context: Optional[Dict[str, Any]] = None) -> str:
521452
pass

0 commit comments

Comments
 (0)