Skip to content

Commit ead3bd8

Browse files
committed
fix: Add types for session_functions.py API call responses
1 parent d7c0183 commit ead3bd8

File tree

2 files changed

+130
-54
lines changed

2 files changed

+130
-54
lines changed

supertokens_python/recipe/session/recipe_implementation.py

Lines changed: 23 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
SessionDoesNotExistError,
3838
SessionInformationResult,
3939
SessionObj,
40-
TokenInfo,
4140
)
4241
from .jwks import JWKClient
4342
from .jwt import ParsedJWTInfo, parse_jwt_without_signature_verification
@@ -93,27 +92,20 @@ async def create_new_session(
9392
log_debug_message("createNewSession: Finished")
9493

9594
payload = parse_jwt_without_signature_verification(
96-
result["accessToken"]["token"]
95+
result.accessToken.token
9796
).payload
9897

99-
refresh_token = result["refreshToken"]
100-
refresh_token_info = TokenInfo(
101-
refresh_token["token"],
102-
refresh_token["expiry"],
103-
refresh_token["createdTime"],
104-
)
105-
10698
new_session = Session(
10799
self,
108100
self.config,
109-
result["accessToken"]["token"],
101+
result.accessToken.token,
110102
build_front_token(
111-
result["session"]["userId"], result["accessToken"]["expiry"], payload
103+
result.session.userId, result.accessToken.expiry, payload
112104
),
113-
refresh_token_info,
114-
result.get("antiCsrfToken"),
115-
result["session"]["handle"],
116-
result["session"]["userId"],
105+
result.refreshToken,
106+
result.antiCsrfToken,
107+
result.session.handle,
108+
result.session.userId,
117109
payload,
118110
None,
119111
True,
@@ -231,28 +223,28 @@ async def get_session(
231223

232224
log_debug_message("getSession: Success!")
233225

234-
if "accessToken" in response:
226+
if response.accessToken is not None:
235227
payload = parse_jwt_without_signature_verification(
236-
response["accessToken"]["token"]
228+
response.accessToken.token
237229
).payload
238-
access_token_str = response["accessToken"]["token"]
239-
expiry_time = response["accessToken"]["expiry"]
230+
access_token_str = response.accessToken.token
231+
expiry_time = response.accessToken.expiry
240232
access_token_updated = True
241233
else:
242234
payload = access_token_obj.payload
243235
access_token_str = access_token
244-
expiry_time = response["session"]["expiryTime"]
236+
expiry_time = response.session.expiryTime
245237
access_token_updated = False
246238

247239
session = Session(
248240
self,
249241
self.config,
250242
access_token_str,
251-
build_front_token(response["session"]["userId"], expiry_time, payload),
243+
build_front_token(response.session.userId, expiry_time, payload),
252244
None, # refresh_token
253245
anti_csrf_token,
254-
response["session"]["handle"],
255-
response["session"]["userId"],
246+
response.session.handle,
247+
response.session.userId,
256248
payload,
257249
None,
258250
access_token_updated,
@@ -287,29 +279,22 @@ async def refresh_session(
287279
log_debug_message("refreshSession: Success!")
288280

289281
payload = parse_jwt_without_signature_verification(
290-
response["accessToken"]["token"]
282+
response.accessToken.token,
291283
).payload
292284

293-
new_refresh_token: Dict[str, Any] = response["refreshToken"]
294-
new_refresh_token_info = TokenInfo(
295-
new_refresh_token["token"],
296-
new_refresh_token["expiry"],
297-
new_refresh_token["createdTime"],
298-
)
299-
300285
session = Session(
301286
self,
302287
self.config,
303-
response["accessToken"]["token"],
288+
response.accessToken.token,
304289
build_front_token(
305-
response["session"]["userId"],
306-
response["accessToken"]["expiry"],
290+
response.session.userId,
291+
response.accessToken.expiry,
307292
payload,
308293
),
309-
new_refresh_token_info,
310-
response.get("antiCsrfToken"),
311-
response["session"]["handle"],
312-
response["session"]["userId"],
294+
response.refreshToken,
295+
response.antiCsrfToken,
296+
response.session.handle,
297+
response.session.userId,
313298
user_data_in_access_token=payload,
314299
req_res_info=None,
315300
access_token_updated=True,

supertokens_python/recipe/session/session_functions.py

Lines changed: 107 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import annotations
1515

1616
import time
17-
from typing import TYPE_CHECKING, Any, Dict, List, Union
17+
from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional
1818

1919
from supertokens_python.recipe.session.interfaces import SessionInformationResult
2020

@@ -28,6 +28,7 @@
2828
from supertokens_python.logger import log_debug_message
2929
from supertokens_python.normalised_url_path import NormalisedURLPath
3030
from supertokens_python.process_state import AllowedProcessStates, ProcessState
31+
from supertokens_python.recipe.session.interfaces import TokenInfo
3132

3233
from .exceptions import (
3334
TryRefreshTokenError,
@@ -37,13 +38,61 @@
3738
)
3839

3940

41+
class CreateOrRefreshAPIResponseSession:
42+
def __init__(self, handle: str, userId: str, userDataInJWT: Any):
43+
self.handle = handle
44+
self.userId = userId
45+
self.userDataInJWT = userDataInJWT
46+
47+
48+
class CreateOrRefreshAPIResponse:
49+
def __init__(
50+
self,
51+
session: CreateOrRefreshAPIResponseSession,
52+
accessToken: TokenInfo,
53+
refreshToken: TokenInfo,
54+
antiCsrfToken: Optional[str],
55+
):
56+
self.session = session
57+
self.accessToken = accessToken
58+
self.refreshToken = refreshToken
59+
self.antiCsrfToken = antiCsrfToken
60+
61+
62+
class GetSessionAPIResponseSession:
63+
def __init__(
64+
self, handle: str, userId: str, userDataInJWT: Dict[str, Any], expiryTime: int
65+
) -> None:
66+
self.handle = handle
67+
self.userId = userId
68+
self.userDataInJWT = userDataInJWT
69+
self.expiryTime = expiryTime
70+
71+
72+
class GetSessionAPIResponseAccessToken:
73+
def __init__(self, token: str, expiry: int, createdTime: int) -> None:
74+
self.token = token
75+
self.expiry = expiry
76+
self.createdTime = createdTime
77+
78+
79+
class GetSessionAPIResponse:
80+
def __init__(
81+
self,
82+
session: GetSessionAPIResponseSession,
83+
accessToken: Optional[GetSessionAPIResponseAccessToken] = None,
84+
) -> None:
85+
self.session = session
86+
self.accessToken = accessToken
87+
88+
4089
async def create_new_session(
4190
recipe_implementation: RecipeImplementation,
4291
user_id: str,
4392
disable_anti_csrf: bool,
4493
access_token_payload: Union[None, Dict[str, Any]],
4594
session_data_in_database: Union[None, Dict[str, Any]],
46-
) -> Dict[str, Any]:
95+
) -> CreateOrRefreshAPIResponse:
4796
if session_data_in_database is None:
4897
session_data_in_database = {}
4998
if access_token_payload is None:
@@ -66,7 +115,22 @@ async def create_new_session(
66115

67116
response.pop("status", None)
68117

69-
return response # FIXME: type the response
118+
return CreateOrRefreshAPIResponse(
119+
CreateOrRefreshAPIResponseSession(
120+
response["handle"], response["userId"], response["userDataInJWT"]
121+
),
122+
TokenInfo(
123+
response["accessToken"]["token"],
124+
response["accessToken"]["expiry"],
125+
response["accessToken"]["createdTime"],
126+
),
127+
TokenInfo(
128+
response["refreshToken"]["token"],
129+
response["refreshToken"]["expiry"],
130+
response["refreshToken"]["createdTime"],
131+
),
132+
response["antiCsrfToken"] if "antiCsrfToken" in response else None,
133+
)
70134

71135

72136
async def get_session(
@@ -75,9 +139,9 @@ async def get_session(
75139
anti_csrf_token: Union[str, None],
76140
do_anti_csrf_check: bool,
77141
always_check_core: bool,
78-
) -> Dict[str, Any]:
142+
) -> GetSessionAPIResponse:
79143
config = recipe_implementation.config
80-
access_token_info = None
144+
access_token_info: Optional[Dict[str, Any]] = None
81145

82146
try:
83147
access_token_info = get_info_from_access_token(
@@ -179,14 +243,14 @@ async def get_session(
179243
and not always_check_core
180244
and access_token_info["parentRefreshTokenHash1"] is None
181245
):
182-
return {
183-
"session": {
184-
"handle": access_token_info["sessionHandle"],
185-
"userId": access_token_info["userId"],
186-
"userDataInJWT": access_token_info["userData"],
187-
"expiryTime": access_token_info["expiryTime"],
188-
}
189-
}
246+
return GetSessionAPIResponse(
247+
GetSessionAPIResponseSession(
248+
access_token_info["sessionHandle"],
249+
access_token_info["userId"],
250+
access_token_info["userData"],
251+
access_token_info["expiryTime"],
252+
)
253+
)
190254

191255
ProcessState.get_instance().add_state(
192256
AllowedProcessStates.CALLING_SERVICE_IN_VERIFY
@@ -206,7 +270,19 @@ async def get_session(
206270
)
207271
if response["status"] == "OK":
208272
response.pop("status", None)
209-
return response # FIXME: type the response
273+
return GetSessionAPIResponse(
274+
GetSessionAPIResponseSession(
275+
response["session"]["handle"],
276+
response["session"]["userId"],
277+
response["session"]["userData"],
278+
response["session"]["expiresAt"],
279+
),
280+
GetSessionAPIResponseAccessToken(
281+
response["accessToken"]["token"],
282+
response["accessToken"]["expiry"],
283+
response["accessToken"]["createdTime"],
284+
),
285+
)
210286
if response["status"] == "UNAUTHORISED":
211287
log_debug_message("getSession: Returning UNAUTHORISED because of core response")
212288
raise_unauthorised_exception(response["message"])
@@ -222,7 +298,7 @@ async def refresh_session(
222298
refresh_token: str,
223299
anti_csrf_token: Union[str, None],
224300
disable_anti_csrf: bool,
225-
) -> Dict[str, Any]:
301+
) -> CreateOrRefreshAPIResponse:
226302
data = {
227303
"refreshToken": refresh_token,
228304
"enableAntiCsrf": (
@@ -249,7 +325,22 @@ async def refresh_session(
249325
)
250326
if response["status"] == "OK":
251327
response.pop("status", None)
252-
return response # FIXME: type the response
328+
return CreateOrRefreshAPIResponse(
329+
CreateOrRefreshAPIResponseSession(
330+
response["handle"], response["userId"], response["userDataInJWT"]
331+
),
332+
TokenInfo(
333+
response["accessToken"]["token"],
334+
response["accessToken"]["expiry"],
335+
response["accessToken"]["createdTime"],
336+
),
337+
TokenInfo(
338+
response["refreshToken"]["token"],
339+
response["refreshToken"]["expiry"],
340+
response["refreshToken"]["createdTime"],
341+
),
342+
response["antiCsrfToken"] if "antiCsrfToken" in response else None,
343+
)
253344
if response["status"] == "UNAUTHORISED":
254345
log_debug_message(
255346
"refreshSession: Returning UNAUTHORISED because of core response"

0 commit comments

Comments
 (0)