Skip to content

Commit c0c2395

Browse files
committed
refactor: Implement and use custom JWKClient
1 parent 7fb96c3 commit c0c2395

File tree

5 files changed

+136
-36
lines changed

5 files changed

+136
-36
lines changed

supertokens_python/recipe/session/access_token.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from supertokens_python.utils import get_timestamp_ms
2323

2424
from .exceptions import raise_try_refresh_token_exception
25+
from .jwks import JWKClient
2526
from .jwt import ParsedJWTInfo
2627

2728

@@ -45,37 +46,37 @@ def sanitize_number(n: Any) -> Union[Union[int, float], None]:
4546

4647
def get_info_from_access_token(
4748
jwt_info: ParsedJWTInfo,
48-
jwk_clients: List[jwt.PyJWKClient],
49+
jwk_clients: List[JWKClient],
4950
do_anti_csrf_check: bool,
5051
):
52+
# TODO: Add different tests to verify this works as expected
5153
try:
5254
payload: Optional[Dict[str, Any]] = None
5355

56+
if jwt_info.version < 3:
57+
# It won't have kid. So we'll have to try the token against all the keys from all the jwk_clients
58+
for client in jwk_clients:
59+
keys = client.get_latest_keys()
60+
61+
for k in keys:
62+
try:
63+
payload = jwt.decode(jwt_info.raw_token_string, str(k.key), algorithms=["RS256"]) # type: ignore
64+
break
65+
except DecodeError:
66+
pass
67+
68+
if payload is None:
69+
raise PyJWKClientError("No key found")
70+
71+
# Came here means token is v3 or above
5472
for client in jwk_clients:
55-
try:
56-
# TODO: verify this works as expected
57-
signing_key: str = client.get_signing_key_from_jwt(jwt_info.raw_token_string).key # type: ignore
58-
payload = jwt.decode( # type: ignore
59-
jwt_info.raw_token_string,
60-
signing_key,
61-
algorithms=["RS256"],
62-
options={"verify_signature": True, "verify_exp": True},
63-
)
64-
except PyJWKClientError as e:
65-
# If no kid is present, this error is thrown
66-
# So we'll have to try the token against all the keys if it's v2
67-
if jwt_info.version == 2:
68-
for client in jwk_clients:
69-
keys = client.get_signing_keys()
70-
for k in keys:
71-
try:
72-
payload = jwt.decode(jwt_info.raw_token_string, str(k.key), algorithms=["RS256"]) # type: ignore
73-
except DecodeError:
74-
pass
75-
if payload is None:
76-
raise e
77-
except DecodeError as e:
78-
raise e
73+
matching_key = client.get_matching_key_from_jwt(jwt_info.raw_token_string)
74+
payload = jwt.decode( # type: ignore
75+
jwt_info.raw_token_string,
76+
matching_key,
77+
algorithms=["RS256"],
78+
options={"verify_signature": True, "verify_exp": True},
79+
)
7980

8081
assert payload is not None
8182

@@ -90,7 +91,7 @@ def get_info_from_access_token(
9091
user_id = sanitize_string(payload.get("sub"))
9192
expiry_time = sanitize_number(
9293
payload.get("exp", 0) * 1000
93-
) # FIXME: Is adding 0 as default okay?
94+
) # FIXME: Is using 0 as default okay?
9495
time_created = sanitize_number(payload.get("iat", 0) * 1000)
9596
user_data = payload
9697

supertokens_python/recipe/session/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,5 @@
3131
ACCESS_CONTROL_EXPOSE_HEADERS = "Access-Control-Expose-Headers"
3232

3333
available_token_transfer_methods: List[TokenTransferMethod] = ["cookie", "header"]
34+
35+
JWKCacheMaxAgeInMs = 60000
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import json
2+
import urllib.request
3+
from typing import List, Optional
4+
from urllib.error import URLError
5+
6+
from jwt import PyJWK, PyJWKSet
7+
from jwt.api_jwt import decode_complete as decode_token # type: ignore
8+
9+
from supertokens_python.utils import get_timestamp_ms
10+
11+
from .constants import JWKCacheMaxAgeInMs
12+
13+
14+
class JWKClient:
15+
def __init__(
16+
self,
17+
uri: str,
18+
cooldown_duration: int = 500,
19+
cache_max_age: int = JWKCacheMaxAgeInMs,
20+
):
21+
"""A client for retrieving JSON Web Key Sets (JWKS) from a given URI.
22+
23+
Args:
24+
uri (str): The URI of the JWKS.
25+
cooldown_duration (int, optional): The cooldown duration in ms. Defaults to 500.
26+
cache_max_age (int, optional): The cache max age in ms. Defaults to 300.
27+
28+
Note: The JSON Web Key Set is fetched when no key matches the selection
29+
process but only as frequently as the `self.cooldown_duration` option
30+
allows to prevent abuse. The `self.cache_max_age` option is used to
31+
determine how long the JWKS is cached for.
32+
33+
Whenever you make a call to `get_signing_key_from_jwt`, the JWKS
34+
is fetched if it is older than `self.cache_max_age` ms unless
35+
cooldown is active.
36+
"""
37+
self.uri = uri
38+
self.cooldown_duration = cooldown_duration
39+
self.cache_max_age = cache_max_age
40+
self.timeout_sec = 5
41+
self.last_fetch_time: int = 0
42+
self.jwk_set: Optional[PyJWKSet] = None
43+
44+
def reload(self):
45+
try:
46+
with urllib.request.urlopen(self.uri, timeout=self.timeout_sec) as response:
47+
self.jwk_set = PyJWKSet.from_dict(json.load(response)) # type: ignore
48+
self.last_fetch_time = get_timestamp_ms()
49+
except URLError as e:
50+
raise JWKSRequestError(f'Failed to fetch data from the url, err: "{e}"')
51+
52+
def is_cooling_down(self) -> bool:
53+
return (self.last_fetch_time > 0) and (
54+
get_timestamp_ms() - self.last_fetch_time < self.cooldown_duration
55+
)
56+
57+
def is_fresh(self) -> bool:
58+
return (self.last_fetch_time > 0) and (
59+
get_timestamp_ms() - self.last_fetch_time < self.cache_max_age
60+
)
61+
62+
def get_latest_keys(self) -> List[PyJWK]:
63+
if self.jwk_set is None or not self.is_fresh():
64+
self.reload()
65+
66+
assert self.jwk_set is not None
67+
68+
all_keys: List[PyJWK] = self.jwk_set.keys # type: ignore
69+
70+
return all_keys
71+
72+
def get_matching_key_from_jwt(self, token: str) -> str:
73+
header = decode_token(token, options={"verify_signature": False})["header"]
74+
kid: str = header["kid"] # type: ignore
75+
76+
if self.jwk_set is None or not self.is_fresh():
77+
self.reload()
78+
79+
assert self.jwk_set is not None
80+
81+
try:
82+
return str(self.jwk_set[kid].key) # type: ignore
83+
except KeyError:
84+
if not self.is_cooling_down():
85+
# One more attempt to fetch the latest keys
86+
# and then try to find the key again.
87+
self.reload()
88+
try:
89+
return str(self.jwk_set[kid].key) # type: ignore
90+
except KeyError:
91+
pass
92+
93+
raise JWKSKeyNotFoundError("No key found for the given kid")
94+
95+
96+
class JWKSKeyNotFoundError(Exception):
97+
pass
98+
99+
100+
class JWKSRequestError(Exception):
101+
pass

supertokens_python/recipe/session/recipe_implementation.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
import json
1717
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
1818

19-
import jwt
20-
2119
from supertokens_python.logger import log_debug_message
2220
from supertokens_python.normalised_url_path import NormalisedURLPath
2321
from supertokens_python.utils import resolve
@@ -48,16 +46,17 @@
4846
SessionInformationResult,
4947
SessionObj,
5048
)
49+
from .jwks import JWKClient
5150
from .jwt import ParsedJWTInfo, parse_jwt_without_signature_verification
5251
from .session_class import Session
5352
from .utils import SessionConfig, validate_claims_in_payload
5453

55-
5654
if TYPE_CHECKING:
5755
from typing import List, Union
5856
from supertokens_python import AppInfo
5957
from supertokens_python.querier import Querier
6058

59+
from .constants import JWKCacheMaxAgeInMs
6160
from .interfaces import SessionContainer
6261

6362
protected_props = [
@@ -79,10 +78,9 @@ def __init__(self, querier: Querier, config: SessionConfig, app_info: AppInfo):
7978
self.app_info = app_info
8079

8180
@property
82-
def JWK_Clients(self) -> List[jwt.PyJWKClient]:
83-
# FIXME: Find params OR Implement caching
81+
def JWK_clients(self) -> List[JWKClient]:
8482
return [
85-
jwt.PyJWKClient(uri)
83+
JWKClient(uri, cooldown_duration=500, cache_max_age=JWKCacheMaxAgeInMs)
8684
for uri in self.querier.get_all_core_urls_for_path(".well-known/jwks.json")
8785
]
8886

supertokens_python/recipe/session/session_functions.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from supertokens_python.recipe.session.interfaces import SessionInformationResult
2020

2121
from .access_token import get_info_from_access_token
22+
from .constants import JWKCacheMaxAgeInMs
2223
from .jwt import ParsedJWTInfo
2324

2425
if TYPE_CHECKING:
@@ -28,9 +29,6 @@
2829
from supertokens_python.normalised_url_path import NormalisedURLPath
2930
from supertokens_python.process_state import AllowedProcessStates, ProcessState
3031

31-
JWKCacheMaxAgeInMs = 60000
32-
33-
3432
from .exceptions import (
3533
TryRefreshTokenError,
3634
raise_token_theft_exception,
@@ -83,7 +81,7 @@ async def get_session(
8381
try:
8482
access_token_info = get_info_from_access_token(
8583
parsed_access_token,
86-
recipe_implementation.JWK_Clients, # FIXME: Use JWKS
84+
recipe_implementation.JWK_clients,
8785
config.anti_csrf == "VIA_TOKEN" and do_anti_csrf_check,
8886
)
8987

0 commit comments

Comments
 (0)