Skip to content

Commit a74482d

Browse files
Merge pull request #429 from supertokens/feat/cache-control-jwks
feat: Add cache control for jwks endpoint
2 parents 364896d + 22d5c20 commit a74482d

File tree

11 files changed

+136
-22
lines changed

11 files changed

+136
-22
lines changed

CHANGELOG.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2929
- Dashboard APIs now return a status code `403` for all non-GET requests if the currently logged in Dashboard User is not listed in the `admins` array
3030
- Now ignoring protected props in the payload in `create_new_session` and `create_new_session_without_request_response`
3131

32-
## [0.15.3] - 2023-09-24
32+
## [0.15.3] - 2023-09-25
3333

3434
- Handle 429 rate limiting from SaaS core instances
35+
- Add `Cache-Control` header for jwks endpoint `/jwt/jwks.json`
36+
- Add `validity_in_secs` to the return value of overridable `get_jwks` recipe function.
37+
- This can be used to control the `Cache-Control` header mentioned above.
38+
- It defaults to `60` or the value set in the cache-control header returned by the core
39+
- This is optional (so you are not required to update your overrides). Returning `None` means that the header won't be set
40+
3541

3642
## [0.15.2] - 2023-09-23
3743

supertokens_python/querier.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ async def __get_headers_with_api_version(self, path: NormalisedURLPath):
160160

161161
async def send_get_request(
162162
self, path: NormalisedURLPath, params: Union[Dict[str, Any], None] = None
163-
):
163+
) -> Dict[str, Any]:
164164
if params is None:
165165
params = {}
166166

@@ -180,7 +180,7 @@ async def send_post_request(
180180
path: NormalisedURLPath,
181181
data: Union[Dict[str, Any], None] = None,
182182
test: bool = False,
183-
):
183+
) -> Dict[str, Any]:
184184
if data is None:
185185
data = {}
186186

@@ -207,7 +207,7 @@ async def f(url: str, method: str) -> Response:
207207

208208
async def send_delete_request(
209209
self, path: NormalisedURLPath, params: Union[Dict[str, Any], None] = None
210-
):
210+
) -> Dict[str, Any]:
211211
if params is None:
212212
params = {}
213213

@@ -224,7 +224,7 @@ async def f(url: str, method: str) -> Response:
224224

225225
async def send_put_request(
226226
self, path: NormalisedURLPath, data: Union[Dict[str, Any], None] = None
227-
):
227+
) -> Dict[str, Any]:
228228
if data is None:
229229
data = {}
230230

@@ -262,7 +262,7 @@ async def __send_request_helper(
262262
http_function: Callable[[str, str], Awaitable[Response]],
263263
no_of_tries: int,
264264
retry_info_map: Optional[Dict[str, int]] = None,
265-
) -> Any:
265+
) -> Dict[str, Any]:
266266
if no_of_tries == 0:
267267
raise_general_exception("No SuperTokens core available to query")
268268

@@ -321,10 +321,14 @@ async def __send_request_helper(
321321
+ response.text # type: ignore
322322
)
323323

324+
res: Dict[str, Any] = {"_headers": dict(response.headers)}
325+
324326
try:
325-
return response.json()
327+
res.update(response.json())
326328
except JSONDecodeError:
327-
return response.text
329+
res["_text"] = response.text
330+
331+
return res
328332

329333
except (ConnectionError, NetworkError, ConnectTimeout) as _:
330334
return await self.__send_request_helper(

supertokens_python/recipe/jwt/api/implementation.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,10 @@ async def jwks_get(
2525
self, api_options: APIOptions, user_context: Dict[str, Any]
2626
) -> JWKSGetResponse:
2727
response = await api_options.recipe_implementation.get_jwks(user_context)
28+
29+
if response.validity_in_secs is not None:
30+
api_options.response.set_header(
31+
"Cache-Control", f"max-age={response.validity_in_secs}, must-revalidate"
32+
)
33+
2834
return JWKSGetResponse(response.keys)

supertokens_python/recipe/jwt/interfaces.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@ class CreateJwtResultUnsupportedAlgorithm:
4040

4141

4242
class GetJWKSResult:
43-
def __init__(self, keys: List[JsonWebKey]):
43+
def __init__(self, keys: List[JsonWebKey], validity_in_secs: Optional[int]):
4444
self.keys = keys
45+
self.validity_in_secs = validity_in_secs
4546

4647

4748
class RecipeInterface(ABC):

supertokens_python/recipe/jwt/recipe_implementation.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from supertokens_python.normalised_url_path import NormalisedURLPath
1919
from supertokens_python.querier import Querier
20+
import re
2021

2122
if TYPE_CHECKING:
2223
from .utils import JWTConfig
@@ -32,6 +33,10 @@
3233
from .interfaces import JsonWebKey
3334

3435

36+
# This corresponds to the dynamicSigningKeyOverlapMS in the core
37+
DEFAULT_JWKS_MAX_AGE = 60
38+
39+
3540
class RecipeImplementation(RecipeInterface):
3641
def __init__(self, querier: Querier, config: JWTConfig, app_info: AppInfo):
3742
super().__init__()
@@ -69,11 +74,25 @@ async def get_jwks(self, user_context: Dict[str, Any]) -> GetJWKSResult:
6974
NormalisedURLPath("/.well-known/jwks.json"), {}
7075
)
7176

77+
validity_in_secs = DEFAULT_JWKS_MAX_AGE
78+
cache_control = response["_headers"].get("Cache-Control")
79+
80+
if cache_control is not None:
81+
pattern = r",?\s*max-age=(\d+)(?:,|$)"
82+
max_age_header = re.match(pattern, cache_control)
83+
if max_age_header is not None:
84+
validity_in_secs = int(max_age_header.group(1))
85+
try:
86+
validity_in_secs = int(validity_in_secs)
87+
except Exception:
88+
validity_in_secs = DEFAULT_JWKS_MAX_AGE
89+
7290
keys: List[JsonWebKey] = []
7391
for key in response["keys"]:
7492
keys.append(
7593
JsonWebKey(
7694
key["kty"], key["kid"], key["n"], key["e"], key["alg"], key["use"]
7795
)
7896
)
79-
return GetJWKSResult(keys)
97+
98+
return GetJWKSResult(keys, validity_in_secs)

supertokens_python/recipe/multitenancy/recipe_implementation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ async def associate_user_to_tenant(
253253
AssociateUserToTenantPhoneNumberAlreadyExistsError,
254254
AssociateUserToTenantThirdPartyUserAlreadyExistsError,
255255
]:
256-
response: Dict[str, Any] = await self.querier.send_post_request(
256+
response = await self.querier.send_post_request(
257257
NormalisedURLPath(
258258
f"{tenant_id or DEFAULT_TENANT_ID}/recipe/multitenancy/tenant/user"
259259
),

supertokens_python/recipe/session/recipe_implementation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ async def regenerate_access_token(
469469
) -> Union[RegenerateAccessTokenOkResult, None]:
470470
if new_access_token_payload is None:
471471
new_access_token_payload = {}
472-
response: Dict[str, Any] = await self.querier.send_post_request(
472+
response = await self.querier.send_post_request(
473473
NormalisedURLPath("/recipe/session/regenerate"),
474474
{"accessToken": access_token, "userDataInJWT": new_access_token_payload},
475475
)

supertokens_python/recipe/session/session_functions.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,6 @@ async def create_new_session(
122122
},
123123
)
124124

125-
response.pop("status", None)
126-
127125
return CreateOrRefreshAPIResponse(
128126
CreateOrRefreshAPIResponseSession(
129127
response["session"]["handle"],
@@ -281,7 +279,6 @@ async def get_session(
281279
NormalisedURLPath("/recipe/session/verify"), data
282280
)
283281
if response["status"] == "OK":
284-
response.pop("status", None)
285282
return GetSessionAPIResponse(
286283
GetSessionAPIResponseSession(
287284
response["session"]["handle"],
@@ -351,7 +348,6 @@ async def refresh_session(
351348
NormalisedURLPath("/recipe/session/refresh"), data
352349
)
353350
if response["status"] == "OK":
354-
response.pop("status", None)
355351
return CreateOrRefreshAPIResponse(
356352
CreateOrRefreshAPIResponseSession(
357353
response["session"]["handle"],

supertokens_python/recipe/userroles/recipe_implementation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ async def add_role_to_user(
5151
NormalisedURLPath(f"{tenant_id}/recipe/user/role"),
5252
params,
5353
)
54-
if response.get("status") == "OK":
54+
if response["status"] == "OK":
5555
return AddRoleToUserOkResult(
5656
did_user_already_have_role=response["didUserAlreadyHaveRole"]
5757
)
@@ -93,7 +93,7 @@ async def get_users_that_have_role(
9393
NormalisedURLPath(f"{tenant_id}/recipe/role/users"),
9494
params,
9595
)
96-
if response.get("status") == "OK":
96+
if response["status"] == "OK":
9797
return GetUsersThatHaveRoleOkResult(users=response["users"])
9898
return UnknownRoleError()
9999

@@ -115,7 +115,7 @@ async def get_permissions_for_role(
115115
response = await self.querier.send_get_request(
116116
NormalisedURLPath("/recipe/role/permissions"), params
117117
)
118-
if response.get("status") == "OK":
118+
if response["status"] == "OK":
119119
return GetPermissionsForRoleOkResult(permissions=response["permissions"])
120120
return UnknownRoleError()
121121

@@ -126,7 +126,7 @@ async def remove_permissions_from_role(
126126
response = await self.querier.send_post_request(
127127
NormalisedURLPath("/recipe/role/permissions/remove"), params
128128
)
129-
if response.get("status") == "OK":
129+
if response["status"] == "OK":
130130
return RemovePermissionsFromRoleOkResult()
131131
return UnknownRoleError()
132132

tests/jwt/test_get_JWKS.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@
1616

1717
from _pytest.fixtures import fixture
1818
from fastapi import FastAPI
19+
from typing import Optional, Dict, Any
1920
from pytest import mark
2021
from starlette.requests import Request
2122
from starlette.testclient import TestClient
2223
from supertokens_python import InputAppInfo, SupertokensConfig, init
2324
from supertokens_python.framework.fastapi import get_middleware
2425
from supertokens_python.recipe import jwt
25-
from supertokens_python.recipe.jwt.interfaces import APIInterface
26+
from supertokens_python.recipe.jwt.interfaces import APIInterface, RecipeInterface
2627
from supertokens_python.recipe.session.asyncio import create_new_session
2728
from tests.utils import clean_st, reset, setup_st, start_st
2829

@@ -83,6 +84,20 @@ async def test_that_default_getJWKS_api_does_not_work_when_disabled(
8384

8485

8586
async def test_that_default_getJWKS_works_fine(driver_config_client: TestClient):
87+
custom_validity: Optional[int] = -1 # -1 means no override
88+
89+
def func_override(oi: RecipeInterface):
90+
oi_get_jwks = oi.get_jwks
91+
92+
async def get_jwks(user_context: Dict[str, Any]):
93+
res = await oi_get_jwks(user_context)
94+
if custom_validity != -1:
95+
res.validity_in_secs = custom_validity
96+
return res
97+
98+
oi.get_jwks = get_jwks
99+
return oi
100+
86101
init(
87102
supertokens_config=SupertokensConfig("http://localhost:3567"),
88103
app_info=InputAppInfo(
@@ -91,12 +106,37 @@ async def test_that_default_getJWKS_works_fine(driver_config_client: TestClient)
91106
website_domain="supertokens.io",
92107
),
93108
framework="fastapi",
94-
recipe_list=[jwt.init()],
109+
recipe_list=[jwt.init(override=jwt.OverrideConfig(functions=func_override))],
95110
)
96111
start_st()
97112

98113
response = driver_config_client.get(url="/auth/jwt/jwks.json")
99114

115+
# Default:
100116
assert response.status_code == 200
101117
data = response.json()
118+
assert data.keys() == {"keys"}
102119
assert len(data["keys"]) > 0
120+
assert data["keys"][0].keys() == {"kty", "kid", "n", "e", "alg", "use"}
121+
122+
assert response.headers["cache-control"] == "max-age=60, must-revalidate"
123+
124+
# Override cache control:
125+
custom_validity = 1
126+
response = driver_config_client.get(url="/auth/jwt/jwks.json")
127+
128+
assert response.status_code == 200
129+
data = response.json()
130+
assert len(data["keys"]) > 0
131+
132+
assert response.headers["cache-control"] == "max-age=1, must-revalidate"
133+
134+
# Disable cache control:
135+
custom_validity = None
136+
response = driver_config_client.get(url="/auth/jwt/jwks.json")
137+
138+
assert response.status_code == 200
139+
data = response.json()
140+
assert len(data["keys"]) > 0
141+
142+
assert "cache-control" not in response.headers

tests/test_querier.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import asyncio
2222
import respx
2323
import httpx
24+
import json
2425
from supertokens_python import init, SupertokensConfig
2526
from supertokens_python.querier import Querier, NormalisedURLPath
2627

@@ -148,3 +149,44 @@ async def call_api(id_: int):
148149
assert call_count2 == 6
149150

150151
assert api.call_count == 12
152+
153+
154+
async def test_querier_text_and_headers():
155+
args = get_st_init_args([session.init()])
156+
args["supertokens_config"] = SupertokensConfig("http://localhost:6789")
157+
init(**args) # type: ignore
158+
start_st()
159+
160+
Querier.api_version = "3.0"
161+
q = Querier.get_instance()
162+
163+
with respx_mock() as mocker:
164+
text = "foo"
165+
mocker.get("http://localhost:6789/text-api").mock(
166+
httpx.Response(200, text=text, headers={"greet": "hello"})
167+
)
168+
169+
res = await q.send_get_request(NormalisedURLPath("/text-api"), {})
170+
assert res == {
171+
"_text": "foo",
172+
"_headers": {
173+
"greet": "hello",
174+
"content-type": "text/plain; charset=utf-8",
175+
"content-length": str(len("foo")),
176+
},
177+
}
178+
179+
body = {"bar": "baz"}
180+
mocker.get("http://localhost:6789/json-api").mock(
181+
httpx.Response(200, json=body, headers={"greet": "hi"})
182+
)
183+
184+
res = await q.send_get_request(NormalisedURLPath("/json-api"), {})
185+
assert res == {
186+
"bar": "baz",
187+
"_headers": {
188+
"greet": "hi",
189+
"content-type": "application/json",
190+
"content-length": str(len(json.dumps(body))),
191+
},
192+
}

0 commit comments

Comments
 (0)