Skip to content

Commit ea3915c

Browse files
committed
fix: Improve JWKClient and fix get_session of session_functions
1 parent c0c2395 commit ea3915c

File tree

5 files changed

+63
-51
lines changed

5 files changed

+63
-51
lines changed

supertokens_python/recipe/session/access_token.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from supertokens_python.utils import get_timestamp_ms
2323

2424
from .exceptions import raise_try_refresh_token_exception
25-
from .jwks import JWKClient
25+
from .jwks import JWKClient, JWKSRequestError
2626
from .jwt import ParsedJWTInfo
2727

2828

@@ -55,28 +55,36 @@ def get_info_from_access_token(
5555

5656
if jwt_info.version < 3:
5757
# It won't have kid. So we'll have to try the token against all the keys from all the jwk_clients
58+
# If any of them work, we'll use that payload
5859
for client in jwk_clients:
59-
keys = client.get_latest_keys()
60+
try:
61+
keys = client.get_latest_keys()
6062

61-
for k in keys:
62-
try:
63-
payload = jwt.decode(jwt_info.raw_token_string, str(k.key), algorithms=["RS256"]) # type: ignore
64-
break
65-
except DecodeError:
66-
pass
63+
for k in keys:
64+
try:
65+
payload = jwt.decode(jwt_info.raw_token_string, str(k.key), algorithms=["RS256"]) # type: ignore
66+
break
67+
except DecodeError:
68+
pass
69+
except JWKSRequestError:
70+
continue
71+
72+
if payload is not None:
73+
break
6774

6875
if payload is None:
6976
raise PyJWKClientError("No key found")
7077

71-
# Came here means token is v3 or above
72-
for client in jwk_clients:
73-
matching_key = client.get_matching_key_from_jwt(jwt_info.raw_token_string)
74-
payload = jwt.decode( # type: ignore
75-
jwt_info.raw_token_string,
76-
matching_key,
77-
algorithms=["RS256"],
78-
options={"verify_signature": True, "verify_exp": True},
79-
)
78+
elif jwt_info.version > 3:
79+
for client in jwk_clients:
80+
matching_key = client.get_matching_key_from_jwt(jwt_info.raw_token_string)
81+
payload = jwt.decode( # type: ignore
82+
jwt_info.raw_token_string,
83+
matching_key,
84+
algorithms=["RS256"],
85+
options={"verify_signature": True, "verify_exp": True},
86+
)
87+
break
8088

8189
assert payload is not None
8290

supertokens_python/recipe/session/constants.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
1414
from __future__ import annotations
15+
1516
from typing import TYPE_CHECKING, List
1617

1718
if TYPE_CHECKING:
@@ -32,4 +33,5 @@
3233

3334
available_token_transfer_methods: List[TokenTransferMethod] = ["cookie", "header"]
3435

35-
JWKCacheMaxAgeInMs = 60000
36+
JWKCacheMaxAgeInMs = 60 * 1000
37+
JWKRequestCooldownInMs = 500 * 1000

supertokens_python/recipe/session/jwks.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,22 @@
88

99
from supertokens_python.utils import get_timestamp_ms
1010

11-
from .constants import JWKCacheMaxAgeInMs
11+
from .constants import JWKCacheMaxAgeInMs, JWKRequestCooldownInMs
1212

1313

1414
class JWKClient:
1515
def __init__(
1616
self,
1717
uri: str,
18-
cooldown_duration: int = 500,
18+
cooldown_duration: int = JWKRequestCooldownInMs,
1919
cache_max_age: int = JWKCacheMaxAgeInMs,
2020
):
2121
"""A client for retrieving JSON Web Key Sets (JWKS) from a given URI.
2222
2323
Args:
2424
uri (str): The URI of the JWKS.
25-
cooldown_duration (int, optional): The cooldown duration in ms. Defaults to 500.
26-
cache_max_age (int, optional): The cache max age in ms. Defaults to 300.
25+
cooldown_duration (int, optional): The cooldown duration in ms. Defaults to 500 seconds.
26+
cache_max_age (int, optional): The cache max age in ms. Defaults to 5 minutes.
2727
2828
Note: The JSON Web Key Set is fetched when no key matches the selection
2929
process but only as frequently as the `self.cooldown_duration` option
@@ -63,7 +63,8 @@ def get_latest_keys(self) -> List[PyJWK]:
6363
if self.jwk_set is None or not self.is_fresh():
6464
self.reload()
6565

66-
assert self.jwk_set is not None
66+
if self.jwk_set is None:
67+
raise JWKSRequestError("Failed to fetch the latest keys")
6768

6869
all_keys: List[PyJWK] = self.jwk_set.keys # type: ignore
6970

supertokens_python/recipe/session/recipe_implementation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
from supertokens_python import AppInfo
5757
from supertokens_python.querier import Querier
5858

59-
from .constants import JWKCacheMaxAgeInMs
59+
from .constants import JWKCacheMaxAgeInMs, JWKRequestCooldownInMs
6060
from .interfaces import SessionContainer
6161

6262
protected_props = [
@@ -80,7 +80,7 @@ def __init__(self, querier: Querier, config: SessionConfig, app_info: AppInfo):
8080
@property
8181
def JWK_clients(self) -> List[JWKClient]:
8282
return [
83-
JWKClient(uri, cooldown_duration=500, cache_max_age=JWKCacheMaxAgeInMs)
83+
JWKClient(uri, cooldown_duration=JWKRequestCooldownInMs, cache_max_age=JWKCacheMaxAgeInMs)
8484
for uri in self.querier.get_all_core_urls_for_path(".well-known/jwks.json")
8585
]
8686

supertokens_python/recipe/session/session_functions.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -145,32 +145,33 @@ async def get_session(
145145

146146
# If we get here we either have a V2 token that doesn't pass verification or a valid V3> token
147147
# anti-csrf check if accesstokenInfo is not undefined which means token verification was successful
148-
149-
if config.anti_csrf == "VIA_TOKEN" and do_anti_csrf_check:
150-
if access_token_info is not None:
151-
if (
152-
anti_csrf_token is None
153-
or anti_csrf_token != access_token_info["antiCsrfToken"]
154-
):
155-
if anti_csrf_token is None:
156-
log_debug_message(
157-
"getSession: Returning TRY_REFRESH_TOKEN because antiCsrfToken is missing from request"
158-
)
159-
raise_try_refresh_token_exception(
160-
"Provided antiCsrfToken is undefined. If you do not want anti-csrf check for this API, please set doAntiCsrfCheck to false for this API"
161-
)
162-
else:
163-
log_debug_message(
164-
"getSession: Returning TRY_REFRESH_TOKEN because the passed antiCsrfToken is not the same as in the access token"
165-
)
166-
raise_try_refresh_token_exception("anti-csrf check failed")
167-
168-
elif config.anti_csrf == "VIA_CUSTOM_HEADER":
169-
# The function should never be called by this (we check this outside the function as well)
170-
# There we can add a bit more information to the error, so that's the primary check, this is just making sure.
171-
raise Exception(
172-
"Please either use VIA_TOKEN, NONE or call with doAntiCsrfCheck false"
173-
)
148+
149+
if do_anti_csrf_check:
150+
if config.anti_csrf == "VIA_TOKEN":
151+
if access_token_info is not None:
152+
if (
153+
anti_csrf_token is None
154+
or anti_csrf_token != access_token_info["antiCsrfToken"]
155+
):
156+
if anti_csrf_token is None:
157+
log_debug_message(
158+
"getSession: Returning TRY_REFRESH_TOKEN because antiCsrfToken is missing from request"
159+
)
160+
raise_try_refresh_token_exception(
161+
"Provided antiCsrfToken is undefined. If you do not want anti-csrf check for this API, please set doAntiCsrfCheck to false for this API"
162+
)
163+
else:
164+
log_debug_message(
165+
"getSession: Returning TRY_REFRESH_TOKEN because the passed antiCsrfToken is not the same as in the access token"
166+
)
167+
raise_try_refresh_token_exception("anti-csrf check failed")
168+
169+
elif config.anti_csrf == "VIA_CUSTOM_HEADER":
170+
# The function should never be called by this (we check this outside the function as well)
171+
# There we can add a bit more information to the error, so that's the primary check, this is just making sure.
172+
raise Exception(
173+
"Please either use VIA_TOKEN, NONE or call with doAntiCsrfCheck false"
174+
)
174175

175176
if (
176177
access_token_info is not None

0 commit comments

Comments
 (0)