Skip to content

Commit 1f161ef

Browse files
committed
fix: Clear session changes for edge cases
1 parent ed905b8 commit 1f161ef

File tree

9 files changed

+172
-22
lines changed

9 files changed

+172
-22
lines changed

supertokens_python/framework/django/django_response.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import json
1515
from datetime import datetime
1616
from math import ceil
17-
from typing import Any, Dict, Union
17+
from typing import Any, Dict, Optional
1818

1919
from supertokens_python.framework.response import BaseResponse
2020

@@ -42,7 +42,7 @@ def set_cookie(
4242
value: str,
4343
expires: int,
4444
path: str = "/",
45-
domain: Union[str, None] = None,
45+
domain: Optional[str] = None,
4646
secure: bool = False,
4747
httponly: bool = False,
4848
samesite: str = "lax",
@@ -69,11 +69,14 @@ def set_status_code(self, status_code: int):
6969
def set_header(self, key: str, value: str):
7070
self.response[key] = value
7171

72-
def get_header(self, key: str):
72+
def get_header(self, key: str) -> Optional[str]:
7373
if self.response.has_header(key):
7474
return self.response[key]
7575
return None
7676

77+
def remove_header(self, key: str):
78+
del self.response[key]
79+
7780
def set_json_content(self, content: Dict[str, Any]):
7881
if not self.response_sent:
7982
self.set_header("Content-Type", "application/json; charset=utf-8")

supertokens_python/framework/fastapi/fastapi_response.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import json
1515
from math import ceil
1616
from time import time
17-
from typing import Any, Dict, Union
17+
from typing import Any, Dict, Optional
1818

1919
from supertokens_python.framework.response import BaseResponse
2020

@@ -44,7 +44,7 @@ def set_cookie(
4444
value: str,
4545
expires: int,
4646
path: str = "/",
47-
domain: Union[str, None] = None,
47+
domain: Optional[str] = None,
4848
secure: bool = False,
4949
httponly: bool = False,
5050
samesite: str = "lax",
@@ -78,9 +78,12 @@ def set_cookie(
7878
def set_header(self, key: str, value: str):
7979
self.response.headers[key] = value
8080

81-
def get_header(self, key: str) -> Union[str, None]:
81+
def get_header(self, key: str) -> Optional[str]:
8282
return self.response.headers.get(key, None)
8383

84+
def remove_header(self, key: str):
85+
del self.response.headers[key]
86+
8487
def set_status_code(self, status_code: int):
8588
if not self.status_set:
8689
self.response.status_code = status_code

supertokens_python/framework/flask/flask_response.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
1414
import json
15-
from typing import Any, Dict, List, Union
15+
from typing import Any, Dict, List, Optional
1616

1717
from supertokens_python.framework.response import BaseResponse
1818

@@ -40,7 +40,7 @@ def set_cookie(
4040
value: str,
4141
expires: int,
4242
path: str = "/",
43-
domain: Union[str, None] = None,
43+
domain: Optional[str] = None,
4444
secure: bool = False,
4545
httponly: bool = False,
4646
samesite: str = "lax",
@@ -59,9 +59,12 @@ def set_cookie(
5959
def set_header(self, key: str, value: str):
6060
self.response.headers.set(key, value)
6161

62-
def get_header(self, key: str) -> Union[None, str]:
62+
def get_header(self, key: str) -> Optional[str]:
6363
return self.response.headers.get(key)
6464

65+
def remove_header(self, key: str):
66+
del self.response.headers[key]
67+
6568
def set_status_code(self, status_code: int):
6669
if not self.status_set:
6770
self.response.status_code = status_code

supertokens_python/framework/response.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# under the License.
1414

1515
from abc import ABC, abstractmethod
16-
from typing import Any, Dict, Union
16+
from typing import Any, Dict, Optional
1717

1818

1919
class BaseResponse(ABC):
@@ -31,19 +31,23 @@ def set_cookie(
3131
# max_age: Union[int, None] = None,
3232
expires: int,
3333
path: str = "/",
34-
domain: Union[str, None] = None,
34+
domain: Optional[str] = None,
3535
secure: bool = False,
3636
httponly: bool = False,
3737
samesite: str = "lax",
3838
):
3939
pass
4040

4141
@abstractmethod
42-
def set_header(self, key: str, value: str):
42+
def set_header(self, key: str, value: str) -> None:
4343
pass
4444

4545
@abstractmethod
46-
def get_header(self, key: str) -> Union[str, None]:
46+
def get_header(self, key: str) -> Optional[str]:
47+
pass
48+
49+
@abstractmethod
50+
def remove_header(self, key: str) -> None:
4751
pass
4852

4953
@abstractmethod

supertokens_python/recipe/session/cookie_and_header.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,11 @@ def set_header(response: BaseResponse, key: str, value: str, allow_duplicate: bo
9797
response.set_header(key, value)
9898

9999

100+
def remove_header(response: BaseResponse, key: str):
101+
if response.get_header(key) is not None:
102+
response.remove_header(key)
103+
104+
100105
def get_cookie(request: BaseRequest, key: str):
101106
cookie_val = request.get_cookie(key)
102107
if cookie_val is None:
@@ -171,23 +176,31 @@ def get_rid_header(request: BaseRequest):
171176

172177

173178
def clear_session_from_all_token_transfer_methods(
174-
response: BaseResponse, recipe: SessionRecipe, request: BaseRequest
179+
response: BaseResponse, recipe: SessionRecipe
175180
):
181+
# We are clearing the session in all transfermethods to be sure to override cookies in case they have been already added to the response.
182+
# This is done to handle the following use-case:
183+
# If the app overrides signInPOST to check the ban status of the user after the original implementation and throwing an UNAUTHORISED error
184+
# In this case: the SDK has attached cookies to the response, but none was sent with the request
185+
# We can't know which to clear since we can't reliably query or remove the set-cookie header added to the response (causes issues in some frameworks, i.e.: hapi)
186+
# The safe solution in this case is to overwrite all the response cookies/headers with an empty value, which is what we are doing here.
176187
for transfer_method in available_token_transfer_methods:
177-
if get_token(request, "access", transfer_method) is not None:
178-
_clear_session(response, recipe.config, transfer_method)
188+
_clear_session(response, recipe.config, transfer_method)
179189

180190

181191
def _clear_session(
182192
response: BaseResponse,
183193
config: SessionConfig,
184194
transfer_method: TokenTransferMethod,
185195
):
186-
# If we can tell it's a cookie based session we are not clearing using headers
196+
# If we can be specific about which transferMethod we want to clear, there is no reason to clear the other ones
187197
token_types: List[TokenType] = ["access", "refresh"]
188198
for token_type in token_types:
189199
_set_token(response, config, token_type, "", 0, transfer_method)
190200

201+
remove_header(
202+
response, ANTI_CSRF_HEADER_KEY
203+
) # This can be added multiple times in some cases, but that should be OK
191204
set_header(response, FRONT_TOKEN_HEADER_SET_KEY, "remove", False)
192205
set_header(
193206
response, ACCESS_CONTROL_EXPOSE_HEADERS, FRONT_TOKEN_HEADER_SET_KEY, True

supertokens_python/recipe/session/interfaces.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
3434
from .utils import SessionConfig, TokenTransferMethod
3535

3636
if TYPE_CHECKING:
37-
from supertokens_python.framework import BaseRequest, BaseResponse
37+
from supertokens_python.framework import BaseRequest
38+
39+
from supertokens_python.framework import BaseResponse
3840

3941

4042
class SessionObj:

supertokens_python/recipe/session/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ async def on_token_theft_detected(
138138
response: BaseResponse,
139139
) -> BaseResponse:
140140
log_debug_message("Clearing tokens because of TOKEN_THEFT_DETECTED response")
141-
clear_session_from_all_token_transfer_methods(response, recipe, request)
141+
clear_session_from_all_token_transfer_methods(response, recipe)
142142
return await resolve(
143143
self.__on_token_theft_detected(request, session_handle, user_id, response)
144144
)
@@ -159,7 +159,7 @@ async def on_unauthorised(
159159
):
160160
if do_clear_cookies:
161161
log_debug_message("Clearing tokens because of UNAUTHORISED response")
162-
clear_session_from_all_token_transfer_methods(response, recipe, request)
162+
clear_session_from_all_token_transfer_methods(response, recipe)
163163
return await resolve(self.__on_unauthorised(request, message, response))
164164

165165
async def on_invalid_claim(

tests/Fastapi/test_fastapi.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,20 @@
3131
get_session,
3232
refresh_session,
3333
)
34+
from supertokens_python.recipe.session.exceptions import UnauthorisedError
3435
from supertokens_python.recipe.session.framework.fastapi import verify_session
3536
from supertokens_python.recipe.session.interfaces import APIInterface
37+
from supertokens_python.recipe.session.interfaces import APIOptions as SessionAPIOptions
3638
from tests.utils import (
3739
TEST_DRIVER_CONFIG_ACCESS_TOKEN_PATH,
3840
TEST_DRIVER_CONFIG_COOKIE_DOMAIN,
3941
TEST_DRIVER_CONFIG_COOKIE_SAME_SITE,
4042
TEST_DRIVER_CONFIG_REFRESH_TOKEN_PATH,
43+
assert_info_clears_tokens,
4144
clean_st,
4245
extract_all_cookies,
46+
extract_info,
47+
get_st_init_args,
4348
reset,
4449
setup_st,
4550
start_st,
@@ -107,6 +112,16 @@ async def custom_logout(request: Request): # type: ignore
107112
await session.revoke_session()
108113
return {} # type: ignore
109114

115+
@app.post("/create")
116+
async def _create(request: Request): # type: ignore
117+
await create_new_session(request, "userId", {}, {})
118+
return ""
119+
120+
@app.post("/create-throw")
121+
async def _create_throw(request: Request): # type: ignore
122+
await create_new_session(request, "userId", {}, {})
123+
raise UnauthorisedError("unauthorised")
124+
110125
return TestClient(app)
111126

112127

@@ -524,3 +539,87 @@ def test_fastapi_root_path(fastapi_root_path: str):
524539
# The API should migrate (and return 404 here)
525540
response = test_client.get("/auth/signup/email/[email protected]")
526541
assert response.status_code == 404
542+
543+
544+
@mark.asyncio
545+
@mark.parametrize("token_transfer_method", ["cookie", "header"])
546+
async def test_should_clear_all_response_during_refresh_if_unauthorized(
547+
driver_config_client: TestClient, token_transfer_method: str
548+
):
549+
def override_session_apis(oi: APIInterface):
550+
oi_refresh_post = oi.refresh_post
551+
552+
async def refresh_post(
553+
api_options: SessionAPIOptions, user_context: Dict[str, Any]
554+
):
555+
await oi_refresh_post(api_options, user_context)
556+
raise UnauthorisedError("unauthorized", clear_tokens=True)
557+
558+
oi.refresh_post = refresh_post
559+
return oi
560+
561+
init(
562+
**get_st_init_args(
563+
[
564+
session.init(
565+
anti_csrf="VIA_TOKEN",
566+
override=session.InputOverrideConfig(apis=override_session_apis),
567+
)
568+
]
569+
)
570+
) # type: ignore
571+
start_st()
572+
573+
res = driver_config_client.post(
574+
"/create", headers={"st-auth-mode": token_transfer_method}
575+
)
576+
info = extract_info(res)
577+
578+
assert info["accessTokenFromAny"] is not None
579+
assert info["refreshTokenFromAny"] is not None
580+
581+
headers: Dict[str, Any] = {}
582+
cookies: Dict[str, Any] = {}
583+
584+
if token_transfer_method == "header":
585+
headers.update({"authorization": f"Bearer {info['refreshTokenFromAny']}"})
586+
else:
587+
cookies.update(
588+
{"sRefreshToken": info["refreshTokenFromAny"], "sIdRefreshToken": "asdf"}
589+
)
590+
591+
if info["antiCsrf"] is not None:
592+
headers.update({"anti-csrf": info["antiCsrf"]})
593+
594+
res = driver_config_client.post(
595+
"/auth/session/refresh", headers=headers, cookies=cookies
596+
)
597+
info = extract_info(res)
598+
599+
assert res.status_code == 401
600+
assert_info_clears_tokens(info, token_transfer_method)
601+
602+
603+
@mark.asyncio
604+
@mark.parametrize("token_transfer_method", ["cookie", "header"])
605+
async def test_revoking_session_after_create_new_session_with_throwing_unauthorized_error(
606+
driver_config_client: TestClient, token_transfer_method: str
607+
):
608+
init(
609+
**get_st_init_args(
610+
[
611+
session.init(
612+
anti_csrf="VIA_TOKEN",
613+
)
614+
]
615+
)
616+
) # type: ignore
617+
start_st()
618+
619+
res = driver_config_client.post(
620+
"/create-throw", headers={"st-auth-mode": token_transfer_method}
621+
)
622+
info = extract_info(res)
623+
624+
assert res.status_code == 401
625+
assert_info_clears_tokens(info, token_transfer_method)

tests/utils.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,9 @@ def extract_info(response: Response) -> Dict[str, Any]:
239239
access_token = cookies.get("sAccessToken", {}).get("value")
240240
refresh_token = cookies.get("sRefreshToken", {}).get("value")
241241

242+
access_token_from_header = response.headers.get("st-access-token")
243+
refresh_token_from_header = response.headers.get("st-refresh-token")
244+
242245
return {
243246
**cookies,
244247
"accessToken": access_token,
@@ -247,11 +250,31 @@ def extract_info(response: Response) -> Dict[str, Any]:
247250
"status_code": response.status_code,
248251
"body": response.json(),
249252
"antiCsrf": response.headers.get("anti-csrf"),
250-
"accessTokenFromHeader": response.headers.get("st-access-token"),
251-
"refreshTokenFromHeader": response.headers.get("st-refresh-token"),
253+
"accessTokenFromHeader": access_token_from_header,
254+
"refreshTokenFromHeader": refresh_token_from_header,
255+
"accessTokenFromAny": access_token_from_header or access_token,
256+
"refreshTokenFromAny": refresh_token_from_header or refresh_token,
252257
}
253258

254259

260+
def assert_info_clears_tokens(info: Dict[str, Any], token_transfer_method: str):
261+
if token_transfer_method == "cookie":
262+
assert info["accessToken"] == ""
263+
assert info["refreshToken"] == ""
264+
assert info["sAccessToken"]["expires"] == "Thu, 01 Jan 1970 00:00:00 GMT"
265+
assert info["sRefreshToken"]["expires"] == "Thu, 01 Jan 1970 00:00:00 GMT"
266+
assert info["sAccessToken"]["domain"] == ""
267+
assert info["sRefreshToken"]["domain"] == ""
268+
elif token_transfer_method == "header":
269+
assert info["accessTokenFromHeader"] == ""
270+
assert info["refreshTokenFromHeader"] == ""
271+
else:
272+
raise Exception("unknown token transfer method: " + token_transfer_method)
273+
274+
assert info["frontToken"] == "remove"
275+
assert info["antiCsrf"] is None
276+
277+
255278
def get_unix_timestamp(expiry: str):
256279
return int(
257280
datetime.strptime(expiry, "%a, %d %b %Y %H:%M:%S GMT")

0 commit comments

Comments
 (0)