Skip to content

Commit a87c84a

Browse files
authored
fix: configurable jwk cache duration (#512)
* fix: configurable jwk cache duration * fix: rename * fix: rename * fix: rename
1 parent 92bf3c0 commit a87c84a

File tree

11 files changed

+56
-37
lines changed

11 files changed

+56
-37
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313
### Changes
1414

1515
- `refresh_post` and `refresh_session` now clears all user tokens upon CSRF failures and if no tokens are found. See the latest comment on https://github.com/supertokens/supertokens-node/issues/141 for more details.
16+
- Adds `jwks_refresh_interval_sec` config to `session.init` to set the default JWKS cache duration. The default is 4 hours.
1617

1718
## [0.23.0] - 2024-06-24
1819

supertokens_python/recipe/jwt/recipe_implementation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,8 @@ async def get_jwks(self, user_context: Dict[str, Any]) -> GetJWKSResult:
8585
pattern = r",?\s*max-age=(\d+)(?:,|$)"
8686
max_age_header = re.match(pattern, cache_control)
8787
if max_age_header is not None:
88-
validity_in_secs = int(max_age_header.group(1))
8988
try:
90-
validity_in_secs = int(validity_in_secs)
89+
validity_in_secs = int(max_age_header.group(1))
9190
except Exception:
9291
validity_in_secs = DEFAULT_JWKS_MAX_AGE
9392

supertokens_python/recipe/session/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def init(
5151
invalid_claim_status_code: Union[int, None] = None,
5252
use_dynamic_access_token_signing_key: Union[bool, None] = None,
5353
expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None,
54+
jwks_refresh_interval_sec: Union[int, None] = None,
5455
) -> Callable[[AppInfo], RecipeModule]:
5556
return SessionRecipe.init(
5657
cookie_domain,
@@ -65,4 +66,5 @@ def init(
6566
invalid_claim_status_code,
6667
use_dynamic_access_token_signing_key,
6768
expose_access_token_to_frontend_in_cookie_based_auth,
69+
jwks_refresh_interval_sec,
6870
)

supertokens_python/recipe/session/access_token.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from jwt.exceptions import DecodeError
2020

2121
from supertokens_python.logger import log_debug_message
22+
from supertokens_python.recipe.session.utils import SessionConfig
2223
from supertokens_python.utils import get_timestamp_ms
2324

2425
from .exceptions import raise_try_refresh_token_exception
@@ -48,6 +49,7 @@ def sanitize_number(n: Any) -> Union[Union[int, float], None]:
4849

4950

5051
def get_info_from_access_token(
52+
config: SessionConfig,
5153
jwt_info: ParsedJWTInfo,
5254
do_anti_csrf_check: bool,
5355
):
@@ -60,7 +62,7 @@ def get_info_from_access_token(
6062
)
6163

6264
if jwt_info.version >= 3:
63-
matching_keys = get_latest_keys(jwt_info.kid)
65+
matching_keys = get_latest_keys(config, jwt_info.kid)
6466
payload = jwt.decode( # type: ignore
6567
jwt_info.raw_token_string,
6668
matching_keys[0].key, # type: ignore
@@ -70,7 +72,7 @@ def get_info_from_access_token(
7072
else:
7173
# It won't have kid. So we'll have to try the token against all the keys from all the jwk_clients
7274
# If any of them work, we'll use that payload
73-
for k in get_latest_keys():
75+
for k in get_latest_keys(config):
7476
try:
7577
payload = jwt.decode( # type: ignore
7678
jwt_info.raw_token_string,

supertokens_python/recipe/session/constants.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333

3434
available_token_transfer_methods: List[TokenTransferMethod] = ["cookie", "header"]
3535

36-
JWKCacheMaxAgeInMs = 60 * 1000 # 60s
3736
protected_props = [
3837
"sub",
3938
"iat",

supertokens_python/recipe/session/jwks.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,31 +19,32 @@
1919

2020
from jwt import PyJWK, PyJWKSet
2121

22-
from .constants import JWKCacheMaxAgeInMs
23-
22+
from supertokens_python.recipe.session.utils import SessionConfig
2423
from supertokens_python.utils import RWMutex, RWLockContext, get_timestamp_ms
2524
from supertokens_python.querier import Querier
2625
from supertokens_python.logger import log_debug_message
2726

2827

2928
class JWKSConfigType(TypedDict):
30-
cache_max_age: int
3129
request_timeout: int
3230

3331

3432
JWKSConfig: JWKSConfigType = {
35-
"cache_max_age": JWKCacheMaxAgeInMs,
3633
"request_timeout": 10000, # 10s
3734
}
3835

3936

4037
class CachedKeys:
41-
def __init__(self, keys: List[PyJWK]):
38+
def __init__(self, keys: List[PyJWK], refresh_interval_sec: int):
4239
self.keys = keys
4340
self.last_refresh_time = get_timestamp_ms()
41+
self.refresh_interval_sec = refresh_interval_sec
4442

4543
def is_fresh(self):
46-
return get_timestamp_ms() - self.last_refresh_time < JWKSConfig["cache_max_age"]
44+
return (
45+
get_timestamp_ms() - self.last_refresh_time
46+
< self.refresh_interval_sec * 1000
47+
)
4748

4849

4950
cached_keys: Optional[CachedKeys] = None
@@ -86,7 +87,7 @@ def find_matching_keys(
8687
return None
8788

8889

89-
def get_latest_keys(kid: Optional[str] = None) -> List[PyJWK]:
90+
def get_latest_keys(config: SessionConfig, kid: Optional[str] = None) -> List[PyJWK]:
9091
global cached_keys
9192

9293
if environ.get("SUPERTOKENS_ENV") == "testing":
@@ -134,7 +135,7 @@ def get_latest_keys(kid: Optional[str] = None) -> List[PyJWK]:
134135
last_error = e
135136

136137
if cached_jwks is not None: # we found a valid JWKS
137-
cached_keys = CachedKeys(cached_jwks)
138+
cached_keys = CachedKeys(cached_jwks, config.jwks_refresh_interval_sec)
138139
log_debug_message("Returning JWKS from fetch")
139140
matching_keys = find_matching_keys(get_cached_keys(), kid)
140141
if matching_keys is not None:

supertokens_python/recipe/session/recipe.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def __init__(
9292
invalid_claim_status_code: Union[int, None] = None,
9393
use_dynamic_access_token_signing_key: Union[bool, None] = None,
9494
expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None,
95+
jwks_refresh_interval_sec: Union[int, None] = None,
9596
):
9697
super().__init__(recipe_id, app_info)
9798
self.config = validate_and_normalise_user_input(
@@ -108,6 +109,7 @@ def __init__(
108109
invalid_claim_status_code,
109110
use_dynamic_access_token_signing_key,
110111
expose_access_token_to_frontend_in_cookie_based_auth,
112+
jwks_refresh_interval_sec,
111113
)
112114
self.openid_recipe = OpenIdRecipe(
113115
recipe_id,
@@ -307,6 +309,7 @@ def init(
307309
invalid_claim_status_code: Union[int, None] = None,
308310
use_dynamic_access_token_signing_key: Union[bool, None] = None,
309311
expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None,
312+
jwks_refresh_interval_sec: Union[int, None] = None,
310313
):
311314
def func(app_info: AppInfo):
312315
if SessionRecipe.__instance is None:
@@ -325,6 +328,7 @@ def func(app_info: AppInfo):
325328
invalid_claim_status_code,
326329
use_dynamic_access_token_signing_key,
327330
expose_access_token_to_frontend_in_cookie_based_auth,
331+
jwks_refresh_interval_sec,
328332
)
329333
return SessionRecipe.__instance
330334
raise_general_exception(

supertokens_python/recipe/session/session_functions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from supertokens_python.recipe.session.interfaces import SessionInformationResult
2020

2121
from .access_token import get_info_from_access_token
22-
from .constants import JWKCacheMaxAgeInMs
2322
from .jwt import ParsedJWTInfo
2423

2524
if TYPE_CHECKING:
@@ -159,6 +158,7 @@ async def get_session(
159158

160159
try:
161160
access_token_info = get_info_from_access_token(
161+
config,
162162
parsed_access_token,
163163
config.anti_csrf_function_or_string == "VIA_TOKEN" and do_anti_csrf_check,
164164
)
@@ -198,7 +198,8 @@ async def get_session(
198198

199199
# We check if the token was created since the last time we refreshed the keys from the core
200200
# Since we do not know the exact timing of the last refresh, we check against the max age
201-
if time_created <= time.time() - JWKCacheMaxAgeInMs:
201+
202+
if time_created <= time.time() - config.jwks_refresh_interval_sec:
202203
raise e
203204
else:
204205
# Since v3 (and above) tokens contain a kid we can trust the cache refresh mechanism built on top of the pyjwt lib

supertokens_python/recipe/session/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,7 @@ def __init__(
383383
invalid_claim_status_code: int,
384384
use_dynamic_access_token_signing_key: bool,
385385
expose_access_token_to_frontend_in_cookie_based_auth: bool,
386+
jwks_refresh_interval_sec: int,
386387
):
387388
self.session_expired_status_code = session_expired_status_code
388389
self.invalid_claim_status_code = invalid_claim_status_code
@@ -402,6 +403,7 @@ def __init__(
402403
self.override = override
403404
self.framework = framework
404405
self.mode = mode
406+
self.jwks_refresh_interval_sec = jwks_refresh_interval_sec
405407

406408

407409
def validate_and_normalise_user_input(
@@ -424,6 +426,7 @@ def validate_and_normalise_user_input(
424426
invalid_claim_status_code: Union[int, None] = None,
425427
use_dynamic_access_token_signing_key: Union[bool, None] = None,
426428
expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None,
429+
jwks_refresh_interval_sec: Union[int, None] = None,
427430
):
428431
_ = cookie_same_site # we have this otherwise pylint complains that cookie_same_site is unused, but it is being used in the get_cookie_same_site function.
429432
if anti_csrf not in {"VIA_TOKEN", "VIA_CUSTOM_HEADER", "NONE", None}:
@@ -531,6 +534,9 @@ def anti_csrf_function(
531534
if anti_csrf is not None:
532535
anti_csrf_function_or_string = anti_csrf
533536

537+
if jwks_refresh_interval_sec is None:
538+
jwks_refresh_interval_sec = 4 * 3600 # 4 hours
539+
534540
return SessionConfig(
535541
app_info.api_base_path.append(NormalisedURLPath(SESSION_REFRESH)),
536542
cookie_domain,
@@ -547,6 +553,7 @@ def anti_csrf_function(
547553
invalid_claim_status_code,
548554
use_dynamic_access_token_signing_key,
549555
expose_access_token_to_frontend_in_cookie_based_auth,
556+
jwks_refresh_interval_sec,
550557
)
551558

552559

tests/sessions/test_access_token_version.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from supertokens_python.recipe.session.access_token import (
1717
validate_access_token_structure,
1818
)
19+
from supertokens_python.recipe.session.recipe import SessionRecipe
1920
from tests.utils import get_st_init_args, setup_function, start_st, teardown_function
2021

2122
_ = setup_function # type:ignore
@@ -38,6 +39,7 @@ async def test_access_token_v4():
3839
parsed_info = parse_jwt_without_signature_verification(access_token)
3940

4041
res = get_info_from_access_token(
42+
SessionRecipe.get_instance().config,
4143
parsed_info,
4244
False,
4345
)

0 commit comments

Comments
 (0)