13
13
# under the License.
14
14
15
15
from datetime import datetime , timedelta
16
- from typing import Any , Dict , List
16
+ from typing import Any , Dict , List , Optional
17
17
from unittest .mock import MagicMock
18
18
19
- from fastapi import FastAPI
19
+ from fastapi import FastAPI , Depends
20
20
from fastapi .requests import Request
21
+ from fastapi .responses import JSONResponse
21
22
from fastapi .testclient import TestClient
22
23
from pytest import fixture , mark
23
24
41
42
merge_into_access_token_payload ,
42
43
update_session_data_in_database ,
43
44
)
44
- from supertokens_python .recipe .session .interfaces import RecipeInterface
45
+ from supertokens_python .recipe .session .interfaces import (
46
+ RecipeInterface ,
47
+ SessionContainer ,
48
+ )
45
49
from supertokens_python .recipe .session .jwt import (
46
50
parse_jwt_without_signature_verification ,
47
51
)
52
56
refresh_session ,
53
57
revoke_session ,
54
58
)
59
+ from supertokens_python .recipe .session .framework .fastapi import verify_session
55
60
from tests .utils import clean_st , reset , setup_st , start_st
56
61
57
62
pytestmark = mark .asyncio
@@ -251,6 +256,13 @@ async def create_api(request: Request): # type: ignore
251
256
await async_create_new_session (request , "test-user" , {}, {})
252
257
return ""
253
258
259
+ @app .post ("/sessioninfo-optional" )
260
+ async def _session_info (s : Optional [SessionContainer ] = Depends (verify_session (session_required = False ))): # type: ignore
261
+ if s is not None :
262
+ return JSONResponse ({"session" : s .get_handle (), "user_id" : s .get_user_id ()})
263
+ else :
264
+ return JSONResponse ({"message" : "no session" })
265
+
254
266
return TestClient (app )
255
267
256
268
@@ -707,7 +719,7 @@ async def test_that_verify_session_doesnt_always_call_core():
707
719
assert session3 .refresh_token is not None
708
720
709
721
assert (
710
- AllowedProcessStates .CALLING_SERVICE_IN_VERIFY
722
+ AllowedProcessStates .CALLING_SERVICE_IN_VERIFYG
711
723
not in ProcessState .get_instance ().history
712
724
)
713
725
@@ -724,3 +736,49 @@ async def test_that_verify_session_doesnt_always_call_core():
724
736
AllowedProcessStates .CALLING_SERVICE_IN_VERIFY
725
737
in ProcessState .get_instance ().history
726
738
) # Core got called this time
739
+
740
+
741
+ async def test_anti_csrf_header_via_custom_header_check_happens_only_when_access_token_is_provided (
742
+ driver_config_client : TestClient ,
743
+ ):
744
+ args = get_st_init_args ([session .init (anti_csrf = "VIA_CUSTOM_HEADER" , get_token_transfer_method = lambda * _ : "cookie" )]) # type: ignore
745
+ init (** args ) # type: ignore
746
+ start_st ()
747
+
748
+ response = driver_config_client .post ("/create" )
749
+ assert response .status_code == 200
750
+
751
+ # With access token:
752
+ # without RID:
753
+ response = driver_config_client .post ("/sessioninfo-optional" )
754
+ assert response .status_code == 401
755
+ assert response .json () == {"message" : "try refresh token" }
756
+
757
+ # with RID:
758
+ response = driver_config_client .post (
759
+ "/sessioninfo-optional" ,
760
+ headers = {
761
+ "rid" : "session" ,
762
+ },
763
+ )
764
+ assert response .status_code == 200
765
+ assert list (response .json ()) == ["session" , "user_id" ]
766
+
767
+ # Clear access tokens:
768
+ driver_config_client .cookies .clear ()
769
+
770
+ # Without access tokens:
771
+ # without RID:
772
+ response = driver_config_client .post ("/sessioninfo-optional" )
773
+ assert response .status_code == 200
774
+ assert response .json () == {"message" : "no session" }
775
+
776
+ # with RID:
777
+ response = driver_config_client .post (
778
+ "/sessioninfo-optional" ,
779
+ headers = {
780
+ "rid" : "session" ,
781
+ },
782
+ )
783
+ assert response .status_code == 200
784
+ assert response .json () == {"message" : "no session" }
0 commit comments