Skip to content

Commit 49946d2

Browse files
committed
tests: Add tests for multitenancy functions and improve interface
1 parent 73f41ae commit 49946d2

File tree

7 files changed

+189
-13
lines changed

7 files changed

+189
-13
lines changed

coreDriverInterfaceSupported.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"_comment": "contains a list of core-driver interfaces branch names that this core supports",
33
"versions": [
4-
"2.21"
4+
"3.0"
55
]
6-
}
6+
}

supertokens_python/recipe/multitenancy/api/implementation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ async def login_methods_get(
4848
provider_configs_from_core = tenant_config_res.third_party.providers
4949

5050
merged_providers = merge_providers_from_core_and_static(
51-
tenant_id, provider_configs_from_core, provider_inputs_from_static
51+
provider_configs_from_core, provider_inputs_from_static
5252
)
5353

5454
final_provider_list: List[ThirdPartyProvider] = []

supertokens_python/recipe/multitenancy/asyncio/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ async def list_all_tenants(
7676
async def create_or_update_third_party_config(
7777
tenant_id: Optional[str],
7878
config: ProviderConfig,
79-
skip_validation: Optional[bool],
79+
skip_validation: Optional[bool] = None,
8080
user_context: Optional[Dict[str, Any]] = None,
8181
) -> CreateOrUpdateThirdPartyConfigOkResult:
8282
if user_context is None:

supertokens_python/recipe/multitenancy/interfaces.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -173,13 +173,6 @@ async def delete_third_party_config(
173173
) -> DeleteThirdPartyConfigOkResult:
174174
pass
175175

176-
# TODO: Should this be removed?
177-
@abstractmethod
178-
async def list_third_party_configs_for_third_party_id(
179-
self, third_party_id: str, user_context: Dict[str, Any]
180-
) -> ListThirdPartyConfigsForThirdPartyIdOkResult:
181-
pass
182-
183176
# user tenant association
184177
@abstractmethod
185178
async def associate_user_to_tenant(

supertokens_python/recipe/multitenancy/recipe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ async def login_methods_get(
226226
provider_configs_from_core = tenant_config.third_party.providers
227227

228228
merged_providers = merge_providers_from_core_and_static(
229-
tenant_id, provider_configs_from_core, provider_inputs_from_static
229+
provider_configs_from_core, provider_inputs_from_static
230230
)
231231

232232
final_provider_list: List[ThirdPartyProvider] = []

supertokens_python/recipe/multitenancy/syncio/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def list_all_tenants(user_context: Optional[Dict[str, Any]] = None):
6161
def create_or_update_third_party_config(
6262
tenant_id: Optional[str],
6363
config: ProviderConfig,
64-
skip_validation: Optional[bool],
64+
skip_validation: Optional[bool] = None,
6565
user_context: Optional[Dict[str, Any]] = None,
6666
):
6767
if user_context is None:
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
# Copyright (c) 2023, VRAI Labs and/or its affiliates. All rights reserved.
2+
#
3+
# This software is licensed under the Apache License, Version 2.0 (the
4+
# "License") as published by the Apache Software Foundation.
5+
#
6+
# You may not use this file except in compliance with the License. You may
7+
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations
13+
# under the License.
14+
from fastapi import FastAPI
15+
from pytest import mark, fixture
16+
from starlette.testclient import TestClient
17+
18+
from supertokens_python import init
19+
from supertokens_python.framework.fastapi import get_middleware
20+
from supertokens_python.recipe import emailpassword, multitenancy, session
21+
from tests.utils import setup_function, teardown_function, get_st_init_args, start_st
22+
23+
_ = setup_function
24+
_ = teardown_function
25+
26+
pytestmark = mark.asyncio
27+
28+
from supertokens_python.recipe.multitenancy.asyncio import (
29+
create_or_update_tenant,
30+
list_all_tenants,
31+
get_tenant,
32+
delete_tenant,
33+
create_or_update_third_party_config,
34+
delete_third_party_config,
35+
associate_user_to_tenant,
36+
dissociate_user_from_tenant,
37+
)
38+
from supertokens_python.recipe.emailpassword.asyncio import sign_up, get_user_by_id
39+
from supertokens_python.recipe.emailpassword.interfaces import SignUpOkResult
40+
from supertokens_python.recipe.multitenancy.interfaces import TenantConfig
41+
from supertokens_python.recipe.thirdparty.provider import (
42+
ProviderConfig,
43+
ProviderClientConfig,
44+
)
45+
46+
47+
@fixture(scope="function")
48+
async def client():
49+
app = FastAPI()
50+
app.add_middleware(get_middleware())
51+
52+
return TestClient(app)
53+
54+
55+
async def test_tenant_crud():
56+
args = get_st_init_args([multitenancy.init()])
57+
init(**args)
58+
start_st()
59+
60+
await create_or_update_tenant("t1", TenantConfig(email_password_enabled=True))
61+
await create_or_update_tenant("t2", TenantConfig(passwordless_enabled=True))
62+
await create_or_update_tenant("t3", TenantConfig(third_party_enabled=True))
63+
64+
tenants = await list_all_tenants()
65+
assert len(tenants.tenants) == 4
66+
67+
t1_config = await get_tenant("t1")
68+
assert t1_config.email_password.enabled is True
69+
assert t1_config.passwordless.enabled is False
70+
assert t1_config.third_party.enabled is False
71+
assert t1_config.core_config == {}
72+
73+
t2_config = await get_tenant("t2")
74+
assert t2_config.email_password.enabled is False
75+
assert t2_config.passwordless.enabled is True
76+
assert t2_config.third_party.enabled is False
77+
assert t2_config.core_config == {}
78+
79+
t3_config = await get_tenant("t3")
80+
assert t3_config.email_password.enabled is False
81+
assert t3_config.passwordless.enabled is False
82+
assert t3_config.third_party.enabled is True
83+
assert t3_config.core_config == {}
84+
85+
# update tenant1 to add passwordless:
86+
await create_or_update_tenant("t1", TenantConfig(passwordless_enabled=True))
87+
t1_config = await get_tenant("t1")
88+
assert t1_config.email_password.enabled is True
89+
assert t1_config.passwordless.enabled is True
90+
assert t1_config.third_party.enabled is False
91+
assert t1_config.core_config == {}
92+
93+
# update tenant1 to add thirdparty:
94+
await create_or_update_tenant("t1", TenantConfig(third_party_enabled=True))
95+
t1_config = await get_tenant("t1")
96+
assert t1_config.email_password.enabled is True
97+
assert t1_config.passwordless.enabled is True
98+
assert t1_config.third_party.enabled is True
99+
assert t1_config.core_config == {}
100+
101+
# delete tenant2:
102+
await delete_tenant("t2")
103+
tenants = await list_all_tenants()
104+
assert len(tenants.tenants) == 2
105+
assert "t2" not in tenants.tenants
106+
107+
108+
async def test_tenant_thirdparty_config():
109+
args = get_st_init_args([multitenancy.init()])
110+
init(**args)
111+
start_st()
112+
113+
await create_or_update_tenant("t1", TenantConfig(email_password_enabled=True))
114+
await create_or_update_third_party_config(
115+
"t1",
116+
config=ProviderConfig(
117+
third_party_id="google",
118+
name="Google",
119+
clients=[ProviderClientConfig(client_id="abcd")],
120+
),
121+
)
122+
123+
tenant_config = await get_tenant("t1")
124+
125+
assert len(tenant_config.third_party.providers) == 1
126+
assert tenant_config.third_party.providers[0].third_party_id == "google"
127+
assert tenant_config.third_party.providers[0].clients is not None
128+
assert len(tenant_config.third_party.providers[0].clients) == 1
129+
assert tenant_config.third_party.providers[0].clients[0] == "abcd"
130+
131+
# update thirdparty config
132+
await create_or_update_third_party_config(
133+
"t1",
134+
ProviderConfig(
135+
third_party_id="google",
136+
name="Custom name",
137+
clients=[ProviderClientConfig(client_id="efgh")],
138+
),
139+
)
140+
141+
tenant_config = await get_tenant("t1")
142+
assert len(tenant_config.third_party.providers) == 1
143+
assert tenant_config.third_party.providers[0].third_party_id == "google"
144+
assert tenant_config.third_party.providers[0].name == "Custom name"
145+
assert tenant_config.third_party.providers[0].clients is not None
146+
assert len(tenant_config.third_party.providers[0].clients) == 1
147+
assert tenant_config.third_party.providers[0].clients[0] == "efgh"
148+
149+
# delete thirdparty config
150+
await delete_third_party_config("t1", "google")
151+
152+
tenant_config = await get_tenant("t1")
153+
assert len(tenant_config.third_party.providers) == 0
154+
155+
156+
async def test_user_association_and_disassociation_with_tenants():
157+
args = get_st_init_args([session.init(), emailpassword.init(), multitenancy.init()])
158+
init(**args)
159+
start_st()
160+
161+
await create_or_update_tenant("t1", TenantConfig(email_password_enabled=True))
162+
await create_or_update_tenant("t2", TenantConfig(passwordless_enabled=True))
163+
await create_or_update_tenant("t3", TenantConfig(third_party_enabled=True))
164+
165+
signup_response = await sign_up("[email protected]", "password1")
166+
assert isinstance(signup_response, SignUpOkResult)
167+
user_id = signup_response.user.user_id
168+
169+
await associate_user_to_tenant("t1", user_id)
170+
await associate_user_to_tenant("t2", user_id)
171+
await associate_user_to_tenant("t3", user_id)
172+
173+
user = await get_user_by_id(user_id)
174+
assert user is not None
175+
assert len(user.tenant_ids) == 4 # public + 3 tenants
176+
177+
await dissociate_user_from_tenant("t1", user_id)
178+
await dissociate_user_from_tenant("t2", user_id)
179+
await dissociate_user_from_tenant("t3", user_id)
180+
181+
user = await get_user_by_id(user_id)
182+
assert user is not None
183+
assert len(user.tenant_ids) == 1 # public only

0 commit comments

Comments
 (0)