Skip to content

Commit bd91b74

Browse files
Merge pull request #396 from supertokens/fix/multitenancy-recipe-tests
fix: Multitenancy tests for all the recipes
2 parents 6083a1b + 3acd4e7 commit bd91b74

23 files changed

+130
-65
lines changed

supertokens_python/recipe/emailpassword/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ async def default_validator(_: str, __: str) -> Union[str, None]:
4242
return None
4343

4444

45-
async def default_password_validator(_: str, value: str) -> Union[str, None]:
45+
async def default_password_validator(value: str, _tenant_id: str) -> Union[str, None]:
4646
# length >= 8 && < 100
4747
# must have a number and a character
4848
# as per
@@ -62,7 +62,7 @@ async def default_password_validator(_: str, value: str) -> Union[str, None]:
6262
return None
6363

6464

65-
async def default_email_validator(_: str, value: Any) -> Union[str, None]:
65+
async def default_email_validator(value: Any, _tenant_id: str) -> Union[str, None]:
6666
# We check if the email syntax is correct
6767
# As per https://github.com/supertokens/supertokens-auth-react/issues/5#issuecomment-709512438
6868
# Regex from https://stackoverflow.com/a/46181/3867175

supertokens_python/recipe/passwordless/recipe_implementation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ async def get_user_by_phone_number(
206206
) -> Union[User, None]:
207207
param = {"phoneNumber": phone_number}
208208
result = await self.querier.send_get_request(
209-
NormalisedURLPath("{tenant_id}/recipe/user"), param
209+
NormalisedURLPath(f"{tenant_id}/recipe/user"), param
210210
)
211211
if result["status"] == "OK":
212212
email_resp = None

supertokens_python/recipe/thirdparty/api/implementation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
1414
from __future__ import annotations
15+
from supertokens_python.utils import utf_base64decode
1516
from base64 import b64decode
1617
import json
1718

supertokens_python/recipe/thirdparty/provider.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def __init__(
226226
force_pkce: bool = False,
227227
additional_config: Optional[Dict[str, Any]] = None,
228228
# CommonProviderConfig:
229+
third_party_id: str = "temp",
229230
name: Optional[str] = None,
230231
authorization_endpoint: Optional[str] = None,
231232
authorization_endpoint_query_params: Optional[
@@ -261,7 +262,7 @@ def __init__(
261262
)
262263
CommonProviderConfig.__init__(
263264
self,
264-
"temp",
265+
third_party_id,
265266
name,
266267
authorization_endpoint,
267268
authorization_endpoint_query_params,

supertokens_python/recipe/thirdparty/providers/config_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ def merge_providers_from_core_and_static(
166166
merged_provider_input.override = provider_input_from_static.override
167167
break
168168

169+
merged_providers.append(merged_provider_input)
170+
169171
return merged_providers
170172

171173

supertokens_python/recipe/thirdparty/providers/custom.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def get_provider_config_for_client(
4343
force_pkce=client_config.force_pkce,
4444
additional_config=client_config.additional_config,
4545
# CommonProviderConfig
46+
third_party_id=config.third_party_id,
4647
name=config.name,
4748
authorization_endpoint=config.authorization_endpoint,
4849
authorization_endpoint_query_params=config.authorization_endpoint_query_params,
@@ -179,7 +180,8 @@ def __init__(self, provider_config: ProviderConfig):
179180
self.input_config = input_config = self._normalize_input(provider_config)
180181

181182
provider_config_for_client = ProviderConfigForClient(
182-
# Will automatically get replaced by correct value
183+
# Will automatically get replaced with correct value
184+
# in get_provider_config_for_client
183185
# when fetch_and_set_config function runs
184186
client_id="temp",
185187
client_secret=None,

supertokens_python/recipe/thirdparty/recipe_implementation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ async def get_users_by_email(
7777
user["id"],
7878
user["email"],
7979
user["timeJoined"],
80-
response["user"]["tenantIds"],
80+
user["tenantIds"],
8181
ThirdPartyInfo(
8282
user["thirdParty"]["userId"], user["thirdParty"]["id"]
8383
),

supertokens_python/recipe/userroles/recipe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ async def fetch_value(
152152
recipe = UserRolesRecipe.get_instance()
153153

154154
user_roles = await recipe.recipe_implementation.get_roles_for_user(
155-
tenant_id, user_id, user_context
155+
user_id, tenant_id, user_context
156156
)
157157

158158
user_permissions: Set[str] = set()
@@ -186,7 +186,7 @@ async def fetch_value(
186186
) -> List[str]:
187187
recipe = UserRolesRecipe.get_instance()
188188
res = await recipe.recipe_implementation.get_roles_for_user(
189-
tenant_id, user_id, user_context
189+
user_id, tenant_id, user_context
190190
)
191191
return res.roles
192192

tests/Django/test_django.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from urllib.parse import urlencode
1717
from datetime import datetime
1818
from inspect import isawaitable
19+
from base64 import b64encode
1920
from typing import Any, Dict, Union
2021

2122
from django.http import HttpRequest, HttpResponse, JsonResponse
@@ -455,10 +456,10 @@ async def test_thirdparty_parsing_works(self):
455456

456457
start_st()
457458

458-
data = {
459-
"state": "afc596274293e1587315c",
460-
"code": "c7685e261f98e4b3b94e34b3a69ff9cf4.0.rvxt.eE8rO__6hGoqaX1B7ODPmA",
461-
}
459+
state = b64encode(json.dumps({"redirectURI": "http://localhost:3000/redirect" }).encode()).decode()
460+
code = "testing"
461+
462+
data = { "state": state, "code": code}
462463

463464
request = self.factory.post(
464465
"/auth/callback/apple",
@@ -470,11 +471,9 @@ async def test_thirdparty_parsing_works(self):
470471
raise Exception("Should never come here")
471472
response = await temp
472473

473-
self.assertEqual(response.status_code, 200)
474-
self.assertEqual(
475-
response.content,
476-
b'<html><head><script>window.location.replace("http://supertokens.io/auth/callback/apple?state=afc596274293e1587315c&code=c7685e261f98e4b3b94e34b3a69ff9cf4.0.rvxt.eE8rO__6hGoqaX1B7ODPmA");</script></head></html>',
477-
)
474+
self.assertEqual(response.status_code, 303)
475+
self.assertEqual(response.content, b'')
476+
self.assertEqual(response.headers['location'], f"http://localhost:3000/redirect?state={state.replace('=', '%3D')}&code={code}")
478477

479478
@pytest.mark.asyncio
480479
async def test_search_with_multiple_emails(self):

tests/Flask/test_flask.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import json
1616
from typing import Any, Dict, Union
17+
from base64 import b64encode
1718

1819
import pytest
1920
from _pytest.fixtures import fixture
@@ -477,17 +478,15 @@ def test_thirdparty_parsing_works(driver_config_app: Any):
477478
start_st()
478479

479480
test_client = driver_config_app.test_client()
480-
data = {
481-
"state": "afc596274293e1587315c",
482-
"code": "c7685e261f98e4b3b94e34b3a69ff9cf4.0.rvxt.eE8rO__6hGoqaX1B7ODPmA",
483-
}
484-
response = test_client.post("/auth/callback/apple", data=data)
481+
state = b64encode(json.dumps({"redirectURI": "http://localhost:3000/redirect" }).encode()).decode()
482+
code = "testing"
485483

486-
assert response.status_code == 200
487-
assert (
488-
response.data
489-
== b'<html><head><script>window.location.replace("http://supertokens.io/auth/callback/apple?state=afc596274293e1587315c&code=c7685e261f98e4b3b94e34b3a69ff9cf4.0.rvxt.eE8rO__6hGoqaX1B7ODPmA");</script></head></html>'
490-
)
484+
data = { "state": state, "code": code}
485+
res = test_client.post("/auth/callback/apple", data=data)
486+
487+
assert res.status_code == 303
488+
assert res.data == b''
489+
assert res.headers["location"] == f"http://localhost:3000/redirect?state={state.replace('=', '%3D')}&code={code}"
491490

492491

493492
from flask.wrappers import Response

tests/emailpassword/test_multitenancy.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,12 @@
3333
from supertokens_python.recipe.multitenancy.interfaces import TenantConfig
3434

3535
from tests.utils import get_st_init_args
36-
from tests.utils import setup_function, teardown_function, setup_multitenancy_feature
36+
from tests.utils import (
37+
setup_function,
38+
teardown_function,
39+
setup_multitenancy_feature,
40+
start_st,
41+
)
3742

3843

3944
_ = setup_function
@@ -42,7 +47,7 @@
4247
pytestmark = mark.asyncio
4348

4449

45-
async def test_multitenancy_in_user_roles():
50+
async def test_multitenancy_in_emailpassword():
4651
# test that different roles can be assigned for the same user for each tenant
4752
args = get_st_init_args(
4853
[
@@ -53,6 +58,8 @@ async def test_multitenancy_in_user_roles():
5358
]
5459
)
5560
init(**args) # type: ignore
61+
start_st()
62+
5663
setup_multitenancy_feature()
5764

5865
await create_or_update_tenant("t1", TenantConfig(email_password_enabled=True))
@@ -77,14 +84,16 @@ async def test_multitenancy_in_user_roles():
7784

7885
# sign in
7986
ep_user1 = await sign_in("[email protected]", "password1", "t1")
80-
ep_user2 = await sign_in("[email protected]", "password1", "t2")
81-
ep_user3 = await sign_in("[email protected]", "password1", "t3")
87+
ep_user2 = await sign_in("[email protected]", "password2", "t2")
88+
ep_user3 = await sign_in("[email protected]", "password3", "t3")
8289

8390
assert isinstance(ep_user1, SignInOkResult)
8491
assert isinstance(ep_user2, SignInOkResult)
8592
assert isinstance(ep_user3, SignInOkResult)
8693

87-
assert ep_user1.user.user_id == user2.user.user_id == user3.user.user_id
94+
assert ep_user1.user.user_id == user1.user.user_id
95+
assert ep_user2.user.user_id == user2.user.user_id
96+
assert ep_user3.user.user_id == user3.user.user_id
8897

8998
# get user by id:
9099
g_user1 = await get_user_by_id(user1.user.user_id)
@@ -106,8 +115,8 @@ async def test_multitenancy_in_user_roles():
106115

107116
# create password reset token:
108117
pless_reset_link1 = await create_reset_password_token(user1.user.user_id, "t1")
109-
pless_reset_link2 = await create_reset_password_token(user1.user.user_id, "t2")
110-
pless_reset_link3 = await create_reset_password_token(user1.user.user_id, "t3")
118+
pless_reset_link2 = await create_reset_password_token(user2.user.user_id, "t2")
119+
pless_reset_link3 = await create_reset_password_token(user3.user.user_id, "t3")
111120

112121
assert isinstance(pless_reset_link1, CreateResetPasswordOkResult)
113122
assert isinstance(pless_reset_link2, CreateResetPasswordOkResult)

tests/passwordless/test_mutlitenancy.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,12 @@
2727
from supertokens_python.recipe.multitenancy.interfaces import TenantConfig
2828

2929
from tests.utils import get_st_init_args
30-
from tests.utils import setup_function, teardown_function, setup_multitenancy_feature
30+
from tests.utils import (
31+
setup_function,
32+
teardown_function,
33+
setup_multitenancy_feature,
34+
start_st,
35+
)
3136

3237

3338
_ = setup_function
@@ -49,6 +54,7 @@ async def test_multitenancy_functions():
4954
]
5055
)
5156
init(**args)
57+
start_st()
5258
setup_multitenancy_feature()
5359

5460
await create_or_update_tenant("t1", TenantConfig(passwordless_enabled=True))

tests/sessions/claims/test_create_new_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,6 @@ async def test_should_merge_claims_and_passed_access_token_payload_obj(timestamp
6969
s = await create_new_session(dummy_req, "someId")
7070

7171
payload = s.get_access_token_payload()
72-
assert len(payload) == 10
72+
assert len(payload) == 11
7373
assert payload["st-true"] == {"v": True, "t": timestamp}
7474
assert payload["user-custom-claim"] == "foo"

tests/sessions/claims/test_get_claim_value.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,5 +56,5 @@ async def test_should_work_for_non_existing_handle():
5656
init(**new_st_init) # type: ignore
5757
start_st()
5858

59-
res = await get_claim_value("non_existing_handle", TrueClaim)
59+
res = await get_claim_value("non-existing-handle", TrueClaim)
6060
assert isinstance(res, SessionDoesNotExistError)

tests/sessions/claims/test_primitive_array_claim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ async def test_primitive_claim_fetch_value_params_correct():
8181
user_id, ctx = "user_id", {}
8282
await claim.build(user_id, DEFAULT_TENANT_ID, ctx)
8383
assert sync_fetch_value.call_count == 1
84-
assert (user_id, ctx) == sync_fetch_value.call_args_list[0][
84+
assert (user_id, DEFAULT_TENANT_ID, ctx) == sync_fetch_value.call_args_list[0][
8585
0
8686
] # extra [0] refers to call params
8787

tests/sessions/claims/test_primitive_claim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ async def test_primitive_claim_fetch_value_params_correct():
4848
user_id, ctx = "user_id", {}
4949
await claim.build(user_id, DEFAULT_TENANT_ID, ctx)
5050
assert sync_fetch_value.call_count == 1
51-
assert (user_id, ctx) == sync_fetch_value.call_args_list[0][
51+
assert (user_id, DEFAULT_TENANT_ID, ctx) == sync_fetch_value.call_args_list[0][
5252
0
5353
] # extra [0] refers to call params
5454

tests/sessions/claims/test_set_claim_value.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,14 @@ async def test_should_overwrite_claim_value(timestamp: int):
6060
s = await create_new_session(dummy_req, "someId")
6161

6262
payload = s.get_access_token_payload()
63-
assert len(payload) == 9
63+
assert len(payload) == 10
6464
assert payload["st-true"] == {"t": timestamp, "v": True}
6565

6666
await s.set_claim_value(TrueClaim, False)
6767

6868
# Payload should be updated now:
6969
payload = s.get_access_token_payload()
70-
assert len(payload) == 9
70+
assert len(payload) == 10
7171
assert payload["st-true"] == {"t": timestamp, "v": False}
7272

7373

@@ -79,7 +79,7 @@ async def test_should_overwrite_claim_value_using_session_handle(timestamp: int)
7979
s = await create_new_session(dummy_req, "someId")
8080

8181
payload = s.get_access_token_payload()
82-
assert len(payload) == 9
82+
assert len(payload) == 10
8383
assert payload["st-true"] == {"t": timestamp, "v": True}
8484

8585
await set_claim_value(s.get_handle(), TrueClaim, False)

tests/sessions/claims/test_validate_claims_for_session_handle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,6 @@ async def test_should_work_for_not_existing_handle():
5959
start_st()
6060

6161
res = await validate_claims_for_session_handle(
62-
"non_existing_handle", lambda _, __, ___: []
62+
"non-existing-handle", lambda _, __, ___: []
6363
)
6464
assert isinstance(res, SessionDoesNotExistError)

tests/sessions/claims/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from supertokens_python.recipe.session.interfaces import RecipeInterface
99
from tests.utils import st_init_common_args
1010

11-
TrueClaim = BooleanClaim("st-true", fetch_value=lambda _, __: True) # type: ignore
12-
NoneClaim = BooleanClaim("st-none", fetch_value=lambda _, __: None) # type: ignore
11+
TrueClaim = BooleanClaim("st-true", fetch_value=lambda _, __, ___: True) # type: ignore
12+
NoneClaim = BooleanClaim("st-none", fetch_value=lambda _, __, ___: None) # type: ignore
1313

1414

1515
def session_functions_override_with_claim(

tests/test_session.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ async def test_creating_many_sessions_for_one_user_and_looping():
202202

203203
assert len(session_handles) == 7
204204

205-
for i, handle in enumerate(session_handles):
205+
for handle in session_handles:
206206
info = await get_session_information(handle)
207207
assert info is not None
208208
assert info.user_id == "someUser"
@@ -224,19 +224,22 @@ async def test_creating_many_sessions_for_one_user_and_looping():
224224
assert info.custom_claims_in_access_token_payload == {"someKey2": "someValue"}
225225
assert info.session_data_in_database == {"foo": "bar"}
226226

227+
regenerated_session_handles: List[str] = []
227228
# Regenerate access token with new access_token_payload
228-
for i, token in enumerate(access_tokens):
229+
for token in access_tokens:
229230
result = await regenerate_access_token(token, {"bar": "baz"})
230231
assert result is not None
231-
assert (
232-
result.session.handle == session_handles[i]
233-
) # Session handle should remain the same
232+
regenerated_session_handles.append(result.session.handle)
234233

235234
# Confirm that update worked:
236235
info = await get_session_information(result.session.handle)
237236
assert info is not None
238237
assert info.custom_claims_in_access_token_payload == {"bar": "baz"}
239238

239+
# Session handle should remain the same session handle should remain the same
240+
# but order isn't guaranteed so we should sort them
241+
assert sorted(regenerated_session_handles) == sorted(session_handles)
242+
240243
# Try updating invalid handles:
241244
is_updated = await merge_into_access_token_payload("invalidHandle", {"foo": "bar"})
242245
assert is_updated is False

0 commit comments

Comments
 (0)