Skip to content

test: Fix failing unit tests for access token v3 #321

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
May 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ set-up-hooks:
chmod +x .git/hooks/pre-commit

test:
pytest --reruns 3 --reruns-delay 5 ./tests/
pytest -vv --reruns 3 --reruns-delay 5 ./tests/

dev-install:
pip install -r dev-requirements.txt
Expand Down
3 changes: 1 addition & 2 deletions supertokens_python/framework/django/django_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@

if TYPE_CHECKING:
from supertokens_python.recipe.session.interfaces import SessionContainer
from django.http import HttpRequest


class DjangoRequest(BaseRequest):
from django.http import HttpRequest

def __init__(self, request: HttpRequest):
super().__init__()
self.request = request
Expand Down
4 changes: 1 addition & 3 deletions supertokens_python/framework/fastapi/fastapi_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@

if TYPE_CHECKING:
from supertokens_python.recipe.session.interfaces import SessionContainer
from fastapi import Request


class FastApiRequest(BaseRequest):

from fastapi import Request

def __init__(self, request: Request):
super().__init__()
self.request = request
Expand Down
3 changes: 1 addition & 2 deletions supertokens_python/framework/flask/flask_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@

if TYPE_CHECKING:
from supertokens_python.recipe.session.interfaces import SessionContainer
from flask.wrappers import Request


class FlaskRequest(BaseRequest):
from flask.wrappers import Request

def __init__(self, req: Request):
super().__init__()
self.request = req
Expand Down
8 changes: 7 additions & 1 deletion supertokens_python/recipe/session/asyncio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,13 @@ async def create_new_session_without_request_response(
claims_added_by_other_recipes = (
SessionRecipe.get_instance().get_claims_added_by_other_recipes()
)
final_access_token_payload = access_token_payload
app_info = SessionRecipe.get_instance().app_info
issuer = (
app_info.api_domain.get_as_string_dangerous()
+ app_info.api_base_path.get_as_string_dangerous()
)

final_access_token_payload = {**access_token_payload, "iss": issuer}

for claim in claims_added_by_other_recipes:
update = await claim.build(user_id, user_context)
Expand Down
13 changes: 8 additions & 5 deletions supertokens_python/recipe/session/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def raise_token_theft_exception(user_id: str, session_handle: str) -> NoReturn:
def raise_try_refresh_token_exception(ex: Union[str, Exception]) -> NoReturn:
if isinstance(ex, SuperTokensError):
raise ex

raise TryRefreshTokenError(ex) from None


Expand All @@ -36,16 +37,18 @@ def raise_unauthorised_exception(
clear_tokens: bool = True,
response_mutators: Optional[List[ResponseMutator]] = None,
) -> NoReturn:
if response_mutators is None:
response_mutators = []

err = UnauthorisedError(msg, clear_tokens)
err.response_mutators.extend(UnauthorisedError.response_mutators)

if response_mutators is not None:
err.response_mutators.extend(response_mutators)

raise err


class SuperTokensSessionError(SuperTokensError):
response_mutators: List[ResponseMutator] = []
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.response_mutators: List[ResponseMutator] = []


class TokenTheftError(SuperTokensSessionError):
Expand Down
10 changes: 6 additions & 4 deletions supertokens_python/recipe/session/jwks.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def reload(self):
with urllib.request.urlopen(self.uri, timeout=self.timeout_sec) as response:
self.jwk_set = PyJWKSet.from_dict(json.load(response)) # type: ignore
self.last_fetch_time = get_timestamp_ms()
except URLError as e:
raise JWKSRequestError(f'Failed to fetch data from the url, err: "{e}"')
except URLError:
raise JWKSRequestError("Failed to fetch jwk set from the configured uri")

def is_cooling_down(self) -> bool:
return (self.last_fetch_time > 0) and (
Expand Down Expand Up @@ -81,15 +81,17 @@ def get_matching_key_from_jwt(self, token: str) -> PyJWK:

try:
return self.jwk_set[kid] # type: ignore
except IndexError:
except KeyError:
if not self.is_cooling_down():
# One more attempt to fetch the latest keys
# and then try to find the key again.
self.reload()
try:
return self.jwk_set[kid] # type: ignore
except IndexError:
except KeyError:
pass
except Exception:
raise JWKSKeyNotFoundError("No key found for the given kid")

raise JWKSKeyNotFoundError("No key found for the given kid")

Expand Down
19 changes: 11 additions & 8 deletions supertokens_python/recipe/session/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,6 @@ def __init__(
expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None,
):
super().__init__(recipe_id, app_info)
self.openid_recipe = OpenIdRecipe(
recipe_id,
app_info,
None,
None,
override.openid_feature if override is not None else None,
)
self.config = validate_and_normalise_user_input(
app_info,
cookie_domain,
Expand All @@ -117,6 +110,13 @@ def __init__(
use_dynamic_access_token_signing_key,
expose_access_token_to_frontend_in_cookie_based_auth,
)
self.openid_recipe = OpenIdRecipe(
recipe_id,
app_info,
None,
None,
override.openid_feature if override is not None else None,
)
log_debug_message("session init: anti_csrf: %s", self.config.anti_csrf)
if self.config.cookie_domain is not None:
log_debug_message(
Expand Down Expand Up @@ -221,7 +221,10 @@ async def handle_api_request(
async def handle_error(
self, request: BaseRequest, err: SuperTokensError, response: BaseResponse
) -> BaseResponse:
if isinstance(err, SuperTokensSessionError):
if (
isinstance(err, SuperTokensSessionError)
and err.response_mutators is not None
):
for mutator in err.response_mutators:
mutator(response)

Expand Down
4 changes: 3 additions & 1 deletion supertokens_python/recipe/session/session_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ async def attach_to_request_response(
anti_csrf_response_mutator(self.anti_csrf_token)
)

request.set_session(self)
request.set_session(
self
) # Although this is called in recipe/session/framework/**/__init__.py. It's required in case of python because functions like create_new_session(req, "user-id") can be called in the framework view handler as well

async def revoke_session(self, user_context: Union[Any, None] = None) -> None:
if user_context is None:
Expand Down
19 changes: 16 additions & 3 deletions supertokens_python/recipe/session/session_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ async def get_session(
if time_created <= time.time() - JWKCacheMaxAgeInMs:
raise e
else:
# Since v3 (and above) tokens contain a kid we can trust the cache-refresh mechanism of the pyjwt library
# Since v3 (and above) tokens contain a kid we can trust the cache refresh mechanism built on top of the pyjwt lib
# This means we do not need to call the core since the signature wouldn't pass verification anyway.
raise e

Expand Down Expand Up @@ -281,13 +281,26 @@ async def get_session(
response["session"]["handle"],
response["session"]["userId"],
response["session"]["userDataInJWT"],
response["accessToken"]["expiry"],
(
response.get("accessToken", {}).get(
"expiry"
) # if we got a new accesstoken we take the expiry time from there
or (
access_token_info is not None
and access_token_info.get("expiryTime")
) # if we didn't get a new access token but could validate the token take that info (alwaysCheckCore === true, or parentRefreshTokenHash1 !== null)
or parsed_access_token.payload[
"expiryTime"
] # if the token didn't pass validation, but we got here, it means it was a v2 token that we didn't have the key cached for.
), # This will throw error if others are none and 'expiryTime' key doesn't exist in the payload
),
GetSessionAPIResponseAccessToken(
response["accessToken"]["token"],
response["accessToken"]["expiry"],
response["accessToken"]["createdTime"],
),
)
if "accessToken" in response
else None,
)
if response["status"] == "UNAUTHORISED":
log_debug_message("getSession: Returning UNAUTHORISED because of core response")
Expand Down
10 changes: 8 additions & 2 deletions supertokens_python/recipe/session/session_request_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,13 @@ async def create_new_session_in_request(
user_context = set_request_in_user_context_if_not_defined(user_context, request)

claims_added_by_other_recipes = recipe_instance.get_claims_added_by_other_recipes()
final_access_token_payload = access_token_payload
app_info = recipe_instance.app_info
issuer = (
app_info.api_domain.get_as_string_dangerous()
+ app_info.api_base_path.get_as_string_dangerous()
)

final_access_token_payload = {**access_token_payload, "iss": issuer}

for claim in claims_added_by_other_recipes:
update = await claim.build(user_id, user_context)
Expand Down Expand Up @@ -264,7 +270,7 @@ async def create_new_session_in_request(
# We can allow insecure cookie when both website & API domain are localhost or an IP
# When either of them is a different domain, API domain needs to have https and a secure cookie to work
raise Exception(
"Since your API and website domain are different, for sessions to work, please use https on your apiDomain and dont set cookieSecure to false."
"Since your API and website domain are different, for sessions to work, please use https on your apiDomain and don't set cookieSecure to false."
)

disable_anti_csrf = output_transfer_method == "header"
Expand Down
5 changes: 2 additions & 3 deletions supertokens_python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,8 @@ def find_max_version(versions_1: List[str], versions_2: List[str]) -> Union[str,
return max_v


def is_version_gte(version: str, minimum_minor_version: str) -> bool:
assert len(minimum_minor_version.split(".")) == 2
return _get_max_version(version, minimum_minor_version) == version
def is_version_gte(version: str, minimum_version: str) -> bool:
return _get_max_version(version, minimum_version) == version


def _get_max_version(v1: str, v2: str) -> str:
Expand Down
3 changes: 0 additions & 3 deletions tests/Fastapi/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,9 +620,6 @@ async def refresh_post(
assert res.status_code == 401
assert_info_clears_tokens(info, token_transfer_method)

assert info["sIdRefreshToken"]["value"] == ""
assert info["sIdRefreshToken"]["expires"] == "Thu, 01 Jan 1970 00:00:00 GMT"


@mark.asyncio
@mark.parametrize("token_transfer_method", ["cookie", "header"])
Expand Down
2 changes: 2 additions & 0 deletions tests/frontendIntegration/django2x/polls/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
path("", views.get_info, name="/"), # type: ignore
path("update-jwt", views.update_jwt, name="update_jwt"), # type: ignore
path("update-jwt-with-handle", views.update_jwt_with_handle, name="update_jwt_with_handle"), # type: ignore
path("session-claims-error", views.session_claim_error_api, name="session_claim_error_api"), # type: ignore
path("403-without-body", views.without_body_403, name="without_body_403"), # type: ignore
path("testing", views.testing, name="testing"), # type: ignore
path("logout", views.logout, name="logout"), # type: ignore
path("revokeAll", views.revoke_all, name="revokeAll"), # type: ignore
Expand Down
Loading