Skip to content

Commit 756bf40

Browse files
committed
feat: Add tests and types for multitenancy
1 parent 9f0a406 commit 756bf40

File tree

6 files changed

+325
-2
lines changed

6 files changed

+325
-2
lines changed

supertokens_python/recipe/emailpassword/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(
3030
self.time_joined = time_joined
3131
self.tenant_ids = tenant_ids
3232

33-
def __eq__(self, other: "User"): # type: ignore
33+
def __eq__(self, other: object):
3434
return (
3535
isinstance(other, self.__class__)
3636
and self.user_id == other.user_id

supertokens_python/recipe/passwordless/types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,16 @@ def __init__(
4040
self.time_joined = time_joined
4141
self.tenant_ids = tenant_ids
4242

43+
def __eq__(self, other: object) -> bool:
44+
return (
45+
isinstance(other, self.__class__)
46+
and self.user_id == other.user_id
47+
and self.email == other.email
48+
and self.phone_number == other.phone_number
49+
and self.time_joined == other.time_joined
50+
and self.tenant_ids == other.tenant_ids
51+
)
52+
4353

4454
class DeviceCode:
4555
def __init__(self, code_id: str, time_created: str, code_life_time: int):

supertokens_python/recipe/thirdparty/types.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,13 @@ def __init__(self, third_party_user_id: str, third_party_id: str):
2121
self.user_id = third_party_user_id
2222
self.id = third_party_id
2323

24+
def __eq__(self, other: object) -> bool:
25+
return (
26+
isinstance(other, self.__class__)
27+
and self.user_id == other.user_id
28+
and self.id == other.id
29+
)
30+
2431

2532
class RawUserInfoFromProvider:
2633
def __init__(
@@ -47,6 +54,16 @@ def __init__(
4754
self.tenant_ids = tenant_ids
4855
self.third_party_info: ThirdPartyInfo = third_party_info
4956

57+
def __eq__(self, other: object) -> bool:
58+
return (
59+
isinstance(other, self.__class__)
60+
and self.user_id == other.user_id
61+
and self.email == other.email
62+
and self.time_joined == other.time_joined
63+
and self.tenant_ids == other.tenant_ids
64+
and self.third_party_info == other.third_party_info
65+
)
66+
5067

5168
class UserInfoEmail:
5269
def __init__(self, email: str, email_verified: bool):

supertokens_python/recipe/thirdpartypasswordless/recipeimplementation/implementation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ async def thirdparty_manually_create_or_update_user(
287287
if self.tp_manually_create_or_update_user is None:
288288
raise Exception("No thirdparty provider configured")
289289
return await self.tp_manually_create_or_update_user(
290-
third_party_id, third_party_user_id, tenant_id, email, user_context
290+
third_party_id, third_party_user_id, email, tenant_id, user_context
291291
)
292292

293293
async def thirdparty_get_provider(
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
2+
#
3+
# This software is licensed under the Apache License, Version 2.0 (the
4+
# "License") as published by the Apache Software Foundation.
5+
#
6+
# You may not use this file except in compliance with the License. You may
7+
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations
13+
# under the License.
14+
from pytest import mark
15+
from supertokens_python.recipe import session, multitenancy, passwordless
16+
from supertokens_python import init
17+
from supertokens_python.recipe.multitenancy.asyncio import (
18+
create_or_update_tenant,
19+
)
20+
from supertokens_python.recipe.passwordless.asyncio import (
21+
create_code,
22+
consume_code,
23+
get_user_by_id,
24+
get_user_by_email,
25+
ConsumeCodeOkResult,
26+
)
27+
from supertokens_python.recipe.multitenancy.interfaces import TenantConfig
28+
29+
from tests.sessions.claims.utils import get_st_init_args
30+
from tests.utils import setup_function, teardown_function, setup_multitenancy_feature
31+
32+
33+
_ = setup_function
34+
_ = teardown_function
35+
36+
pytestmark = mark.asyncio
37+
38+
39+
async def test_multitenancy_functions():
40+
# test that different roles can be assigned for the same user for each tenant
41+
args = get_st_init_args(
42+
[
43+
session.init(),
44+
passwordless.init(
45+
contact_config=passwordless.ContactEmailOnlyConfig(),
46+
flow_type="USER_INPUT_CODE_AND_MAGIC_LINK",
47+
),
48+
multitenancy.init(),
49+
]
50+
)
51+
init(**args)
52+
setup_multitenancy_feature()
53+
54+
await create_or_update_tenant("t1", TenantConfig(passwordless_enabled=True))
55+
await create_or_update_tenant("t2", TenantConfig(passwordless_enabled=True))
56+
await create_or_update_tenant("t3", TenantConfig(passwordless_enabled=True))
57+
58+
code1 = await create_code(
59+
email="[email protected]", user_input_code="123456", tenant_id="t1"
60+
)
61+
code2 = await create_code(
62+
email="[email protected]", user_input_code="456789", tenant_id="t2"
63+
)
64+
code3 = await create_code(
65+
email="[email protected]", user_input_code="789123", tenant_id="t3"
66+
)
67+
68+
user1 = await consume_code(
69+
pre_auth_session_id=code1.pre_auth_session_id,
70+
device_id=code1.device_id,
71+
user_input_code="123456",
72+
tenant_id="t1",
73+
)
74+
user2 = await consume_code(
75+
pre_auth_session_id=code2.pre_auth_session_id,
76+
device_id=code2.device_id,
77+
user_input_code="456789",
78+
tenant_id="t2",
79+
)
80+
user3 = await consume_code(
81+
pre_auth_session_id=code3.pre_auth_session_id,
82+
device_id=code3.device_id,
83+
user_input_code="789123",
84+
tenant_id="t3",
85+
)
86+
87+
assert isinstance(user1, ConsumeCodeOkResult)
88+
assert isinstance(user2, ConsumeCodeOkResult)
89+
assert isinstance(user3, ConsumeCodeOkResult)
90+
91+
assert user1.user.user_id != user2.user.user_id
92+
assert user2.user.user_id != user3.user.user_id
93+
assert user3.user.user_id != user1.user.user_id
94+
95+
assert user1.user.tenant_ids == ["t1"]
96+
assert user2.user.tenant_ids == ["t2"]
97+
assert user3.user.tenant_ids == ["t3"]
98+
99+
# get user by id:
100+
g_user1 = await get_user_by_id(user1.user.user_id)
101+
g_user2 = await get_user_by_id(user2.user.user_id)
102+
g_user3 = await get_user_by_id(user3.user.user_id)
103+
104+
assert g_user1 == user1.user
105+
assert g_user2 == user2.user
106+
assert g_user3 == user3.user
107+
108+
# get user by email:
109+
by_email_user1 = await get_user_by_email("[email protected]", "t1")
110+
by_email_user2 = await get_user_by_email("[email protected]", "t2")
111+
by_email_user3 = await get_user_by_email("[email protected]", "t3")
112+
113+
assert by_email_user1 == user1.user
114+
assert by_email_user2 == user2.user
115+
assert by_email_user3 == user3.user

tests/thirdparty/test_multitenancy.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
2+
#
3+
# This software is licensed under the Apache License, Version 2.0 (the
4+
# "License") as published by the Apache Software Foundation.
5+
#
6+
# You may not use this file except in compliance with the License. You may
7+
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations
13+
# under the License.
14+
from pytest import mark
15+
from supertokens_python.recipe import session, multitenancy, thirdparty
16+
from supertokens_python import init
17+
from supertokens_python.recipe.multitenancy.asyncio import (
18+
create_or_update_tenant,
19+
create_or_update_third_party_config,
20+
)
21+
from supertokens_python.recipe.thirdparty.asyncio import (
22+
manually_create_or_update_user,
23+
get_user_by_id,
24+
get_users_by_email,
25+
get_user_by_third_party_info,
26+
get_provider,
27+
)
28+
from supertokens_python.recipe.multitenancy.interfaces import TenantConfig
29+
30+
from tests.sessions.claims.utils import get_st_init_args
31+
from tests.utils import setup_function, teardown_function, setup_multitenancy_feature
32+
33+
34+
_ = setup_function
35+
_ = teardown_function
36+
37+
pytestmark = mark.asyncio
38+
39+
40+
async def test_multitenancy_functions():
41+
# test that different roles can be assigned for the same user for each tenant
42+
args = get_st_init_args([session.init(), thirdparty.init(), multitenancy.init()])
43+
init(**args)
44+
setup_multitenancy_feature()
45+
46+
await create_or_update_tenant("t1", TenantConfig(third_party_enabled=True))
47+
await create_or_update_tenant("t2", TenantConfig(third_party_enabled=True))
48+
await create_or_update_tenant("t3", TenantConfig(third_party_enabled=True))
49+
50+
# sign up:
51+
user1a = await manually_create_or_update_user(
52+
"google", "googleid1", "[email protected]", "t1"
53+
)
54+
user1b = await manually_create_or_update_user(
55+
"facebook", "fbid1", "[email protected]", "t1"
56+
)
57+
user2a = await manually_create_or_update_user(
58+
"google", "googleid1", "[email protected]", "t2"
59+
)
60+
user2b = await manually_create_or_update_user(
61+
"facebook", "fbid1", "[email protected]", "t2"
62+
)
63+
user3a = await manually_create_or_update_user(
64+
"google", "googleid1", "[email protected]", "t3"
65+
)
66+
user3b = await manually_create_or_update_user(
67+
"facebook", "fbid1", "[email protected]", "t3"
68+
)
69+
70+
assert user1a.user.tenant_ids == ["t1"]
71+
assert user1b.user.tenant_ids == ["t1"]
72+
assert user2a.user.tenant_ids == ["t2"]
73+
assert user2b.user.tenant_ids == ["t2"]
74+
assert user3a.user.tenant_ids == ["t3"]
75+
assert user3b.user.tenant_ids == ["t3"]
76+
77+
# get user by id:
78+
g_user1a = await get_user_by_id(user1a.user.user_id)
79+
g_user1b = await get_user_by_id(user1b.user.user_id)
80+
g_user2a = await get_user_by_id(user2a.user.user_id)
81+
g_user2b = await get_user_by_id(user2b.user.user_id)
82+
g_user3a = await get_user_by_id(user3a.user.user_id)
83+
g_user3b = await get_user_by_id(user3b.user.user_id)
84+
85+
assert g_user1a == user1a.user
86+
assert g_user1b == user1b.user
87+
assert g_user2a == user2a.user
88+
assert g_user2b == user2b.user
89+
assert g_user3a == user3a.user
90+
assert g_user3b == user3b.user
91+
92+
# get user by email:
93+
by_email_user1 = await get_users_by_email("[email protected]", "t1")
94+
by_email_user2 = await get_users_by_email("[email protected]", "t2")
95+
by_email_user3 = await get_users_by_email("[email protected]", "t3")
96+
97+
assert by_email_user1 == [user1a.user, user1b.user]
98+
assert by_email_user2 == [user2a.user, user2b.user]
99+
assert by_email_user3 == [user3a.user, user3b.user]
100+
101+
# get user by thirdparty id:
102+
g_user_by_tpid1a = await get_user_by_third_party_info("google", "googleid1", "t1")
103+
g_user_by_tpid1b = await get_user_by_third_party_info("facebook", "fbid1", "t1")
104+
g_user_by_tpid2a = await get_user_by_third_party_info("google", "googleid1", "t2")
105+
g_user_by_tpid2b = await get_user_by_third_party_info("facebook", "fbid1", "t2")
106+
g_user_by_tpid3a = await get_user_by_third_party_info("google", "googleid1", "t3")
107+
g_user_by_tpid3b = await get_user_by_third_party_info("facebook", "fbid1", "t3")
108+
109+
assert g_user_by_tpid1a == user1a.user
110+
assert g_user_by_tpid1b == user1b.user
111+
assert g_user_by_tpid2a == user2a.user
112+
assert g_user_by_tpid2b == user2b.user
113+
assert g_user_by_tpid3a == user3a.user
114+
assert g_user_by_tpid3b == user3b.user
115+
116+
117+
async def test_get_provider():
118+
args = get_st_init_args([session.init(), thirdparty.init(), multitenancy.init()])
119+
init(**args)
120+
setup_multitenancy_feature()
121+
122+
await create_or_update_tenant("t1", TenantConfig(third_party_enabled=True))
123+
await create_or_update_tenant("t1", TenantConfig(third_party_enabled=True))
124+
await create_or_update_tenant("t1", TenantConfig(third_party_enabled=True))
125+
126+
await create_or_update_third_party_config(
127+
"t1",
128+
thirdparty.ProviderConfig(
129+
"google", clients=[thirdparty.ProviderClientConfig("a")]
130+
),
131+
)
132+
await create_or_update_third_party_config(
133+
"t1",
134+
thirdparty.ProviderConfig(
135+
"facebook", clients=[thirdparty.ProviderClientConfig("a")]
136+
),
137+
)
138+
139+
await create_or_update_third_party_config(
140+
"t2",
141+
thirdparty.ProviderConfig(
142+
"facebook", clients=[thirdparty.ProviderClientConfig("a")]
143+
),
144+
)
145+
await create_or_update_third_party_config(
146+
"t2",
147+
thirdparty.ProviderConfig(
148+
"discord", clients=[thirdparty.ProviderClientConfig("a")]
149+
),
150+
)
151+
152+
await create_or_update_third_party_config(
153+
"t3",
154+
thirdparty.ProviderConfig(
155+
"discord", clients=[thirdparty.ProviderClientConfig("a")]
156+
),
157+
)
158+
await create_or_update_third_party_config(
159+
"t3",
160+
thirdparty.ProviderConfig(
161+
"linkedin", clients=[thirdparty.ProviderClientConfig("a")]
162+
),
163+
)
164+
165+
provider1 = await get_provider("google", None, "t1")
166+
assert provider1.provider.config.thirdparty_id == "google"
167+
168+
provider2 = await get_provider("facebook", None, "t1")
169+
assert provider2.provider.config.thirdparty_id == "facebook"
170+
171+
provider3 = await get_provider("facebook", None, "t2")
172+
assert provider3.provider.config.thirdparty_id == "facebook"
173+
174+
provider4 = await get_provider("discord", None, "t2")
175+
assert provider4.provider.config.thirdparty_id == "discord"
176+
177+
provider5 = await get_provider("discord", None, "t3")
178+
assert provider5.provider.config.thirdparty_id == "discord"
179+
180+
provider6 = await get_provider("linkedin", None, "t3")
181+
assert provider6.provider.config.thirdparty_id == "linkedin"

0 commit comments

Comments
 (0)