Skip to content

fix: Multitenancy tests for all the recipes #396

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions supertokens_python/recipe/emailpassword/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def default_validator(_: str, __: str) -> Union[str, None]:
return None


async def default_password_validator(_: str, value: str) -> Union[str, None]:
async def default_password_validator(value: str, _tenant_id: str) -> Union[str, None]:
# length >= 8 && < 100
# must have a number and a character
# as per
Expand All @@ -62,7 +62,7 @@ async def default_password_validator(_: str, value: str) -> Union[str, None]:
return None


async def default_email_validator(_: str, value: Any) -> Union[str, None]:
async def default_email_validator(value: Any, _tenant_id: str) -> Union[str, None]:
# We check if the email syntax is correct
# As per https://github.com/supertokens/supertokens-auth-react/issues/5#issuecomment-709512438
# Regex from https://stackoverflow.com/a/46181/3867175
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ async def get_user_by_phone_number(
) -> Union[User, None]:
param = {"phoneNumber": phone_number}
result = await self.querier.send_get_request(
NormalisedURLPath("{tenant_id}/recipe/user"), param
NormalisedURLPath(f"{tenant_id}/recipe/user"), param
)
if result["status"] == "OK":
email_resp = None
Expand Down
1 change: 1 addition & 0 deletions supertokens_python/recipe/thirdparty/api/implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from supertokens_python.utils import utf_base64decode
from base64 import b64decode
import json

Expand Down
3 changes: 2 additions & 1 deletion supertokens_python/recipe/thirdparty/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def __init__(
force_pkce: bool = False,
additional_config: Optional[Dict[str, Any]] = None,
# CommonProviderConfig:
third_party_id: str = "temp",
name: Optional[str] = None,
authorization_endpoint: Optional[str] = None,
authorization_endpoint_query_params: Optional[
Expand Down Expand Up @@ -261,7 +262,7 @@ def __init__(
)
CommonProviderConfig.__init__(
self,
"temp",
third_party_id,
name,
authorization_endpoint,
authorization_endpoint_query_params,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ def merge_providers_from_core_and_static(
merged_provider_input.override = provider_input_from_static.override
break

merged_providers.append(merged_provider_input)

return merged_providers


Expand Down
4 changes: 3 additions & 1 deletion supertokens_python/recipe/thirdparty/providers/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def get_provider_config_for_client(
force_pkce=client_config.force_pkce,
additional_config=client_config.additional_config,
# CommonProviderConfig
third_party_id=config.third_party_id,
name=config.name,
authorization_endpoint=config.authorization_endpoint,
authorization_endpoint_query_params=config.authorization_endpoint_query_params,
Expand Down Expand Up @@ -179,7 +180,8 @@ def __init__(self, provider_config: ProviderConfig):
self.input_config = input_config = self._normalize_input(provider_config)

provider_config_for_client = ProviderConfigForClient(
# Will automatically get replaced by correct value
# Will automatically get replaced with correct value
# in get_provider_config_for_client
# when fetch_and_set_config function runs
client_id="temp",
client_secret=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ async def get_users_by_email(
user["id"],
user["email"],
user["timeJoined"],
response["user"]["tenantIds"],
user["tenantIds"],
ThirdPartyInfo(
user["thirdParty"]["userId"], user["thirdParty"]["id"]
),
Expand Down
4 changes: 2 additions & 2 deletions supertokens_python/recipe/userroles/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ async def fetch_value(
recipe = UserRolesRecipe.get_instance()

user_roles = await recipe.recipe_implementation.get_roles_for_user(
tenant_id, user_id, user_context
user_id, tenant_id, user_context
)

user_permissions: Set[str] = set()
Expand Down Expand Up @@ -186,7 +186,7 @@ async def fetch_value(
) -> List[str]:
recipe = UserRolesRecipe.get_instance()
res = await recipe.recipe_implementation.get_roles_for_user(
tenant_id, user_id, user_context
user_id, tenant_id, user_context
)
return res.roles

Expand Down
17 changes: 8 additions & 9 deletions tests/Django/test_django.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from urllib.parse import urlencode
from datetime import datetime
from inspect import isawaitable
from base64 import b64encode
from typing import Any, Dict, Union

from django.http import HttpRequest, HttpResponse, JsonResponse
Expand Down Expand Up @@ -455,10 +456,10 @@ async def test_thirdparty_parsing_works(self):

start_st()

data = {
"state": "afc596274293e1587315c",
"code": "c7685e261f98e4b3b94e34b3a69ff9cf4.0.rvxt.eE8rO__6hGoqaX1B7ODPmA",
}
state = b64encode(json.dumps({"redirectURI": "http://localhost:3000/redirect" }).encode()).decode()
code = "testing"

data = { "state": state, "code": code}

request = self.factory.post(
"/auth/callback/apple",
Expand All @@ -470,11 +471,9 @@ async def test_thirdparty_parsing_works(self):
raise Exception("Should never come here")
response = await temp

self.assertEqual(response.status_code, 200)
self.assertEqual(
response.content,
b'<html><head><script>window.location.replace("http://supertokens.io/auth/callback/apple?state=afc596274293e1587315c&code=c7685e261f98e4b3b94e34b3a69ff9cf4.0.rvxt.eE8rO__6hGoqaX1B7ODPmA");</script></head></html>',
)
self.assertEqual(response.status_code, 303)
self.assertEqual(response.content, b'')
self.assertEqual(response.headers['location'], f"http://localhost:3000/redirect?state={state.replace('=', '%3D')}&code={code}")

@pytest.mark.asyncio
async def test_search_with_multiple_emails(self):
Expand Down
19 changes: 9 additions & 10 deletions tests/Flask/test_flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import json
from typing import Any, Dict, Union
from base64 import b64encode

import pytest
from _pytest.fixtures import fixture
Expand Down Expand Up @@ -477,17 +478,15 @@ def test_thirdparty_parsing_works(driver_config_app: Any):
start_st()

test_client = driver_config_app.test_client()
data = {
"state": "afc596274293e1587315c",
"code": "c7685e261f98e4b3b94e34b3a69ff9cf4.0.rvxt.eE8rO__6hGoqaX1B7ODPmA",
}
response = test_client.post("/auth/callback/apple", data=data)
state = b64encode(json.dumps({"redirectURI": "http://localhost:3000/redirect" }).encode()).decode()
code = "testing"

assert response.status_code == 200
assert (
response.data
== b'<html><head><script>window.location.replace("http://supertokens.io/auth/callback/apple?state=afc596274293e1587315c&code=c7685e261f98e4b3b94e34b3a69ff9cf4.0.rvxt.eE8rO__6hGoqaX1B7ODPmA");</script></head></html>'
)
data = { "state": state, "code": code}
res = test_client.post("/auth/callback/apple", data=data)

assert res.status_code == 303
assert res.data == b''
assert res.headers["location"] == f"http://localhost:3000/redirect?state={state.replace('=', '%3D')}&code={code}"


from flask.wrappers import Response
Expand Down
23 changes: 16 additions & 7 deletions tests/emailpassword/test_multitenancy.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@
from supertokens_python.recipe.multitenancy.interfaces import TenantConfig

from tests.utils import get_st_init_args
from tests.utils import setup_function, teardown_function, setup_multitenancy_feature
from tests.utils import (
setup_function,
teardown_function,
setup_multitenancy_feature,
start_st,
)


_ = setup_function
Expand All @@ -42,7 +47,7 @@
pytestmark = mark.asyncio


async def test_multitenancy_in_user_roles():
async def test_multitenancy_in_emailpassword():
# test that different roles can be assigned for the same user for each tenant
args = get_st_init_args(
[
Expand All @@ -53,6 +58,8 @@ async def test_multitenancy_in_user_roles():
]
)
init(**args) # type: ignore
start_st()

setup_multitenancy_feature()

await create_or_update_tenant("t1", TenantConfig(email_password_enabled=True))
Expand All @@ -77,14 +84,16 @@ async def test_multitenancy_in_user_roles():

# sign in
ep_user1 = await sign_in("[email protected]", "password1", "t1")
ep_user2 = await sign_in("[email protected]", "password1", "t2")
ep_user3 = await sign_in("[email protected]", "password1", "t3")
ep_user2 = await sign_in("[email protected]", "password2", "t2")
ep_user3 = await sign_in("[email protected]", "password3", "t3")

assert isinstance(ep_user1, SignInOkResult)
assert isinstance(ep_user2, SignInOkResult)
assert isinstance(ep_user3, SignInOkResult)

assert ep_user1.user.user_id == user2.user.user_id == user3.user.user_id
assert ep_user1.user.user_id == user1.user.user_id
assert ep_user2.user.user_id == user2.user.user_id
assert ep_user3.user.user_id == user3.user.user_id

# get user by id:
g_user1 = await get_user_by_id(user1.user.user_id)
Expand All @@ -106,8 +115,8 @@ async def test_multitenancy_in_user_roles():

# create password reset token:
pless_reset_link1 = await create_reset_password_token(user1.user.user_id, "t1")
pless_reset_link2 = await create_reset_password_token(user1.user.user_id, "t2")
pless_reset_link3 = await create_reset_password_token(user1.user.user_id, "t3")
pless_reset_link2 = await create_reset_password_token(user2.user.user_id, "t2")
pless_reset_link3 = await create_reset_password_token(user3.user.user_id, "t3")

assert isinstance(pless_reset_link1, CreateResetPasswordOkResult)
assert isinstance(pless_reset_link2, CreateResetPasswordOkResult)
Expand Down
8 changes: 7 additions & 1 deletion tests/passwordless/test_mutlitenancy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@
from supertokens_python.recipe.multitenancy.interfaces import TenantConfig

from tests.utils import get_st_init_args
from tests.utils import setup_function, teardown_function, setup_multitenancy_feature
from tests.utils import (
setup_function,
teardown_function,
setup_multitenancy_feature,
start_st,
)


_ = setup_function
Expand All @@ -49,6 +54,7 @@ async def test_multitenancy_functions():
]
)
init(**args)
start_st()
setup_multitenancy_feature()

await create_or_update_tenant("t1", TenantConfig(passwordless_enabled=True))
Expand Down
2 changes: 1 addition & 1 deletion tests/sessions/claims/test_create_new_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,6 @@ async def test_should_merge_claims_and_passed_access_token_payload_obj(timestamp
s = await create_new_session(dummy_req, "someId")

payload = s.get_access_token_payload()
assert len(payload) == 10
assert len(payload) == 11
assert payload["st-true"] == {"v": True, "t": timestamp}
assert payload["user-custom-claim"] == "foo"
2 changes: 1 addition & 1 deletion tests/sessions/claims/test_get_claim_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,5 @@ async def test_should_work_for_non_existing_handle():
init(**new_st_init) # type: ignore
start_st()

res = await get_claim_value("non_existing_handle", TrueClaim)
res = await get_claim_value("non-existing-handle", TrueClaim)
assert isinstance(res, SessionDoesNotExistError)
2 changes: 1 addition & 1 deletion tests/sessions/claims/test_primitive_array_claim.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ async def test_primitive_claim_fetch_value_params_correct():
user_id, ctx = "user_id", {}
await claim.build(user_id, DEFAULT_TENANT_ID, ctx)
assert sync_fetch_value.call_count == 1
assert (user_id, ctx) == sync_fetch_value.call_args_list[0][
assert (user_id, DEFAULT_TENANT_ID, ctx) == sync_fetch_value.call_args_list[0][
0
] # extra [0] refers to call params

Expand Down
2 changes: 1 addition & 1 deletion tests/sessions/claims/test_primitive_claim.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def test_primitive_claim_fetch_value_params_correct():
user_id, ctx = "user_id", {}
await claim.build(user_id, DEFAULT_TENANT_ID, ctx)
assert sync_fetch_value.call_count == 1
assert (user_id, ctx) == sync_fetch_value.call_args_list[0][
assert (user_id, DEFAULT_TENANT_ID, ctx) == sync_fetch_value.call_args_list[0][
0
] # extra [0] refers to call params

Expand Down
6 changes: 3 additions & 3 deletions tests/sessions/claims/test_set_claim_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ async def test_should_overwrite_claim_value(timestamp: int):
s = await create_new_session(dummy_req, "someId")

payload = s.get_access_token_payload()
assert len(payload) == 9
assert len(payload) == 10
assert payload["st-true"] == {"t": timestamp, "v": True}

await s.set_claim_value(TrueClaim, False)

# Payload should be updated now:
payload = s.get_access_token_payload()
assert len(payload) == 9
assert len(payload) == 10
assert payload["st-true"] == {"t": timestamp, "v": False}


Expand All @@ -79,7 +79,7 @@ async def test_should_overwrite_claim_value_using_session_handle(timestamp: int)
s = await create_new_session(dummy_req, "someId")

payload = s.get_access_token_payload()
assert len(payload) == 9
assert len(payload) == 10
assert payload["st-true"] == {"t": timestamp, "v": True}

await set_claim_value(s.get_handle(), TrueClaim, False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ async def test_should_work_for_not_existing_handle():
start_st()

res = await validate_claims_for_session_handle(
"non_existing_handle", lambda _, __, ___: []
"non-existing-handle", lambda _, __, ___: []
)
assert isinstance(res, SessionDoesNotExistError)
4 changes: 2 additions & 2 deletions tests/sessions/claims/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from supertokens_python.recipe.session.interfaces import RecipeInterface
from tests.utils import st_init_common_args

TrueClaim = BooleanClaim("st-true", fetch_value=lambda _, __: True) # type: ignore
NoneClaim = BooleanClaim("st-none", fetch_value=lambda _, __: None) # type: ignore
TrueClaim = BooleanClaim("st-true", fetch_value=lambda _, __, ___: True) # type: ignore
NoneClaim = BooleanClaim("st-none", fetch_value=lambda _, __, ___: None) # type: ignore


def session_functions_override_with_claim(
Expand Down
13 changes: 8 additions & 5 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ async def test_creating_many_sessions_for_one_user_and_looping():

assert len(session_handles) == 7

for i, handle in enumerate(session_handles):
for handle in session_handles:
info = await get_session_information(handle)
assert info is not None
assert info.user_id == "someUser"
Expand All @@ -224,19 +224,22 @@ async def test_creating_many_sessions_for_one_user_and_looping():
assert info.custom_claims_in_access_token_payload == {"someKey2": "someValue"}
assert info.session_data_in_database == {"foo": "bar"}

regenerated_session_handles: List[str] = []
# Regenerate access token with new access_token_payload
for i, token in enumerate(access_tokens):
for token in access_tokens:
result = await regenerate_access_token(token, {"bar": "baz"})
assert result is not None
assert (
result.session.handle == session_handles[i]
) # Session handle should remain the same
regenerated_session_handles.append(result.session.handle)

# Confirm that update worked:
info = await get_session_information(result.session.handle)
assert info is not None
assert info.custom_claims_in_access_token_payload == {"bar": "baz"}

# Session handle should remain the same session handle should remain the same
# but order isn't guaranteed so we should sort them
assert sorted(regenerated_session_handles) == sorted(session_handles)

# Try updating invalid handles:
is_updated = await merge_into_access_token_payload("invalidHandle", {"foo": "bar"})
assert is_updated is False
Expand Down
Loading