Skip to content

fix: configurable jwk cache duration #512

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changes

- `refresh_post` and `refresh_session` now clears all user tokens upon CSRF failures and if no tokens are found. See the latest comment on https://github.com/supertokens/supertokens-node/issues/141 for more details.
- Adds `jwks_refresh_interval_sec` config to `session.init` to set the default JWKS cache duration. The default is 4 hours.

## [0.23.0] - 2024-06-24

Expand Down
3 changes: 1 addition & 2 deletions supertokens_python/recipe/jwt/recipe_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,8 @@ async def get_jwks(self, user_context: Dict[str, Any]) -> GetJWKSResult:
pattern = r",?\s*max-age=(\d+)(?:,|$)"
max_age_header = re.match(pattern, cache_control)
if max_age_header is not None:
validity_in_secs = int(max_age_header.group(1))
try:
validity_in_secs = int(validity_in_secs)
validity_in_secs = int(max_age_header.group(1))
except Exception:
validity_in_secs = DEFAULT_JWKS_MAX_AGE

Expand Down
2 changes: 2 additions & 0 deletions supertokens_python/recipe/session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def init(
invalid_claim_status_code: Union[int, None] = None,
use_dynamic_access_token_signing_key: Union[bool, None] = None,
expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None,
jwks_refresh_interval_sec: Union[int, None] = None,
) -> Callable[[AppInfo], RecipeModule]:
return SessionRecipe.init(
cookie_domain,
Expand All @@ -65,4 +66,5 @@ def init(
invalid_claim_status_code,
use_dynamic_access_token_signing_key,
expose_access_token_to_frontend_in_cookie_based_auth,
jwks_refresh_interval_sec,
)
6 changes: 4 additions & 2 deletions supertokens_python/recipe/session/access_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from jwt.exceptions import DecodeError

from supertokens_python.logger import log_debug_message
from supertokens_python.recipe.session.utils import SessionConfig
from supertokens_python.utils import get_timestamp_ms

from .exceptions import raise_try_refresh_token_exception
Expand Down Expand Up @@ -48,6 +49,7 @@ def sanitize_number(n: Any) -> Union[Union[int, float], None]:


def get_info_from_access_token(
config: SessionConfig,
jwt_info: ParsedJWTInfo,
do_anti_csrf_check: bool,
):
Expand All @@ -60,7 +62,7 @@ def get_info_from_access_token(
)

if jwt_info.version >= 3:
matching_keys = get_latest_keys(jwt_info.kid)
matching_keys = get_latest_keys(config, jwt_info.kid)
payload = jwt.decode( # type: ignore
jwt_info.raw_token_string,
matching_keys[0].key, # type: ignore
Expand All @@ -70,7 +72,7 @@ def get_info_from_access_token(
else:
# It won't have kid. So we'll have to try the token against all the keys from all the jwk_clients
# If any of them work, we'll use that payload
for k in get_latest_keys():
for k in get_latest_keys(config):
try:
payload = jwt.decode( # type: ignore
jwt_info.raw_token_string,
Expand Down
1 change: 0 additions & 1 deletion supertokens_python/recipe/session/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@

available_token_transfer_methods: List[TokenTransferMethod] = ["cookie", "header"]

JWKCacheMaxAgeInMs = 60 * 1000 # 60s
protected_props = [
"sub",
"iat",
Expand Down
17 changes: 9 additions & 8 deletions supertokens_python/recipe/session/jwks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,32 @@

from jwt import PyJWK, PyJWKSet

from .constants import JWKCacheMaxAgeInMs

from supertokens_python.recipe.session.utils import SessionConfig
from supertokens_python.utils import RWMutex, RWLockContext, get_timestamp_ms
from supertokens_python.querier import Querier
from supertokens_python.logger import log_debug_message


class JWKSConfigType(TypedDict):
cache_max_age: int
request_timeout: int


JWKSConfig: JWKSConfigType = {
"cache_max_age": JWKCacheMaxAgeInMs,
"request_timeout": 10000, # 10s
}


class CachedKeys:
def __init__(self, keys: List[PyJWK]):
def __init__(self, keys: List[PyJWK], refresh_interval_sec: int):
self.keys = keys
self.last_refresh_time = get_timestamp_ms()
self.refresh_interval_sec = refresh_interval_sec

def is_fresh(self):
return get_timestamp_ms() - self.last_refresh_time < JWKSConfig["cache_max_age"]
return (
get_timestamp_ms() - self.last_refresh_time
< self.refresh_interval_sec * 1000
)


cached_keys: Optional[CachedKeys] = None
Expand Down Expand Up @@ -86,7 +87,7 @@ def find_matching_keys(
return None


def get_latest_keys(kid: Optional[str] = None) -> List[PyJWK]:
def get_latest_keys(config: SessionConfig, kid: Optional[str] = None) -> List[PyJWK]:
global cached_keys

if environ.get("SUPERTOKENS_ENV") == "testing":
Expand Down Expand Up @@ -134,7 +135,7 @@ def get_latest_keys(kid: Optional[str] = None) -> List[PyJWK]:
last_error = e

if cached_jwks is not None: # we found a valid JWKS
cached_keys = CachedKeys(cached_jwks)
cached_keys = CachedKeys(cached_jwks, config.jwks_refresh_interval_sec)
log_debug_message("Returning JWKS from fetch")
matching_keys = find_matching_keys(get_cached_keys(), kid)
if matching_keys is not None:
Expand Down
4 changes: 4 additions & 0 deletions supertokens_python/recipe/session/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def __init__(
invalid_claim_status_code: Union[int, None] = None,
use_dynamic_access_token_signing_key: Union[bool, None] = None,
expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None,
jwks_refresh_interval_sec: Union[int, None] = None,
):
super().__init__(recipe_id, app_info)
self.config = validate_and_normalise_user_input(
Expand All @@ -108,6 +109,7 @@ def __init__(
invalid_claim_status_code,
use_dynamic_access_token_signing_key,
expose_access_token_to_frontend_in_cookie_based_auth,
jwks_refresh_interval_sec,
)
self.openid_recipe = OpenIdRecipe(
recipe_id,
Expand Down Expand Up @@ -307,6 +309,7 @@ def init(
invalid_claim_status_code: Union[int, None] = None,
use_dynamic_access_token_signing_key: Union[bool, None] = None,
expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None,
jwks_refresh_interval_sec: Union[int, None] = None,
):
def func(app_info: AppInfo):
if SessionRecipe.__instance is None:
Expand All @@ -325,6 +328,7 @@ def func(app_info: AppInfo):
invalid_claim_status_code,
use_dynamic_access_token_signing_key,
expose_access_token_to_frontend_in_cookie_based_auth,
jwks_refresh_interval_sec,
)
return SessionRecipe.__instance
raise_general_exception(
Expand Down
5 changes: 3 additions & 2 deletions supertokens_python/recipe/session/session_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from supertokens_python.recipe.session.interfaces import SessionInformationResult

from .access_token import get_info_from_access_token
from .constants import JWKCacheMaxAgeInMs
from .jwt import ParsedJWTInfo

if TYPE_CHECKING:
Expand Down Expand Up @@ -159,6 +158,7 @@ async def get_session(

try:
access_token_info = get_info_from_access_token(
config,
parsed_access_token,
config.anti_csrf_function_or_string == "VIA_TOKEN" and do_anti_csrf_check,
)
Expand Down Expand Up @@ -198,7 +198,8 @@ async def get_session(

# We check if the token was created since the last time we refreshed the keys from the core
# Since we do not know the exact timing of the last refresh, we check against the max age
if time_created <= time.time() - JWKCacheMaxAgeInMs:

if time_created <= time.time() - config.jwks_refresh_interval_sec:
raise e
else:
# Since v3 (and above) tokens contain a kid we can trust the cache refresh mechanism built on top of the pyjwt lib
Expand Down
7 changes: 7 additions & 0 deletions supertokens_python/recipe/session/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ def __init__(
invalid_claim_status_code: int,
use_dynamic_access_token_signing_key: bool,
expose_access_token_to_frontend_in_cookie_based_auth: bool,
jwks_refresh_interval_sec: int,
):
self.session_expired_status_code = session_expired_status_code
self.invalid_claim_status_code = invalid_claim_status_code
Expand All @@ -402,6 +403,7 @@ def __init__(
self.override = override
self.framework = framework
self.mode = mode
self.jwks_refresh_interval_sec = jwks_refresh_interval_sec


def validate_and_normalise_user_input(
Expand All @@ -424,6 +426,7 @@ def validate_and_normalise_user_input(
invalid_claim_status_code: Union[int, None] = None,
use_dynamic_access_token_signing_key: Union[bool, None] = None,
expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None,
jwks_refresh_interval_sec: Union[int, None] = None,
):
_ = cookie_same_site # we have this otherwise pylint complains that cookie_same_site is unused, but it is being used in the get_cookie_same_site function.
if anti_csrf not in {"VIA_TOKEN", "VIA_CUSTOM_HEADER", "NONE", None}:
Expand Down Expand Up @@ -531,6 +534,9 @@ def anti_csrf_function(
if anti_csrf is not None:
anti_csrf_function_or_string = anti_csrf

if jwks_refresh_interval_sec is None:
jwks_refresh_interval_sec = 4 * 3600 # 4 hours

return SessionConfig(
app_info.api_base_path.append(NormalisedURLPath(SESSION_REFRESH)),
cookie_domain,
Expand All @@ -547,6 +553,7 @@ def anti_csrf_function(
invalid_claim_status_code,
use_dynamic_access_token_signing_key,
expose_access_token_to_frontend_in_cookie_based_auth,
jwks_refresh_interval_sec,
)


Expand Down
2 changes: 2 additions & 0 deletions tests/sessions/test_access_token_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from supertokens_python.recipe.session.access_token import (
validate_access_token_structure,
)
from supertokens_python.recipe.session.recipe import SessionRecipe
from tests.utils import get_st_init_args, setup_function, start_st, teardown_function

_ = setup_function # type:ignore
Expand All @@ -38,6 +39,7 @@ async def test_access_token_v4():
parsed_info = parse_jwt_without_signature_verification(access_token)

res = get_info_from_access_token(
SessionRecipe.get_instance().config,
parsed_info,
False,
)
Expand Down
Loading
Loading