|
13 | 13 | # under the License.
|
14 | 14 |
|
15 | 15 | from datetime import datetime, timedelta
|
16 |
| -from typing import Any, Dict, List |
| 16 | +from typing import Any, Dict, List, Optional |
17 | 17 | from unittest.mock import MagicMock
|
18 | 18 |
|
19 |
| -from fastapi import FastAPI |
| 19 | +from fastapi import FastAPI, Depends |
20 | 20 | from fastapi.requests import Request
|
| 21 | +from fastapi.responses import JSONResponse |
21 | 22 | from fastapi.testclient import TestClient
|
22 | 23 | from pytest import fixture, mark
|
23 | 24 |
|
|
41 | 42 | merge_into_access_token_payload,
|
42 | 43 | update_session_data_in_database,
|
43 | 44 | )
|
44 |
| -from supertokens_python.recipe.session.interfaces import RecipeInterface |
| 45 | +from supertokens_python.recipe.session.interfaces import ( |
| 46 | + RecipeInterface, |
| 47 | + SessionContainer, |
| 48 | +) |
45 | 49 | from supertokens_python.recipe.session.jwt import (
|
46 | 50 | parse_jwt_without_signature_verification,
|
47 | 51 | )
|
|
52 | 56 | refresh_session,
|
53 | 57 | revoke_session,
|
54 | 58 | )
|
| 59 | +from supertokens_python.recipe.session.framework.fastapi import verify_session |
55 | 60 | from tests.utils import clean_st, reset, setup_st, start_st
|
56 | 61 |
|
57 | 62 | pytestmark = mark.asyncio
|
@@ -251,6 +256,12 @@ async def create_api(request: Request): # type: ignore
|
251 | 256 | await async_create_new_session(request, "test-user", {}, {})
|
252 | 257 | return ""
|
253 | 258 |
|
| 259 | + @app.post("/sessioninfo-optional") |
| 260 | + async def _session_info(s: Optional[SessionContainer] = Depends(verify_session(session_required=False))): # type: ignore |
| 261 | + if s is not None: |
| 262 | + return JSONResponse({"session": s.get_handle(), "user_id": s.get_user_id()}) |
| 263 | + return JSONResponse({"message": "no session"}) |
| 264 | + |
254 | 265 | return TestClient(app)
|
255 | 266 |
|
256 | 267 |
|
@@ -724,3 +735,49 @@ async def test_that_verify_session_doesnt_always_call_core():
|
724 | 735 | AllowedProcessStates.CALLING_SERVICE_IN_VERIFY
|
725 | 736 | in ProcessState.get_instance().history
|
726 | 737 | ) # Core got called this time
|
| 738 | + |
| 739 | + |
| 740 | +async def test_anti_csrf_header_via_custom_header_check_happens_only_when_access_token_is_provided( |
| 741 | + driver_config_client: TestClient, |
| 742 | +): |
| 743 | + args = get_st_init_args([session.init(anti_csrf="VIA_CUSTOM_HEADER", get_token_transfer_method=lambda *_: "cookie")]) # type: ignore |
| 744 | + init(**args) # type: ignore |
| 745 | + start_st() |
| 746 | + |
| 747 | + response = driver_config_client.post("/create") |
| 748 | + assert response.status_code == 200 |
| 749 | + |
| 750 | + # With access token: |
| 751 | + # without RID: |
| 752 | + response = driver_config_client.post("/sessioninfo-optional") |
| 753 | + assert response.status_code == 401 |
| 754 | + assert response.json() == {"message": "try refresh token"} |
| 755 | + |
| 756 | + # with RID: |
| 757 | + response = driver_config_client.post( |
| 758 | + "/sessioninfo-optional", |
| 759 | + headers={ |
| 760 | + "rid": "session", |
| 761 | + }, |
| 762 | + ) |
| 763 | + assert response.status_code == 200 |
| 764 | + assert list(response.json()) == ["session", "user_id"] |
| 765 | + |
| 766 | + # Clear access tokens: |
| 767 | + driver_config_client.cookies.clear() |
| 768 | + |
| 769 | + # Without access tokens: |
| 770 | + # without RID: |
| 771 | + response = driver_config_client.post("/sessioninfo-optional") |
| 772 | + assert response.status_code == 200 |
| 773 | + assert response.json() == {"message": "no session"} |
| 774 | + |
| 775 | + # with RID: |
| 776 | + response = driver_config_client.post( |
| 777 | + "/sessioninfo-optional", |
| 778 | + headers={ |
| 779 | + "rid": "session", |
| 780 | + }, |
| 781 | + ) |
| 782 | + assert response.status_code == 200 |
| 783 | + assert response.json() == {"message": "no session"} |
0 commit comments