Skip to content

Commit 6083a1b

Browse files
Merge pull request #395 from supertokens/fix/thirdparty-config-mt
fix: Clean up thirdparty recipe and make it consistent with other python recipes
2 parents b262e99 + 679eedc commit 6083a1b

File tree

21 files changed

+231
-170
lines changed

21 files changed

+231
-170
lines changed

supertokens_python/recipe/openid/recipe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from typing import TYPE_CHECKING, List, Union, Optional, Any, Dict
1818

1919
from supertokens_python.querier import Querier
20-
from supertokens_python.recipe.jwt import JWTRecipe
2120

2221
from .api.implementation import APIImplementation
2322
from .api.open_id_discovery_configuration_get import open_id_discovery_configuration_get
@@ -49,6 +48,8 @@ def __init__(
4948
issuer: Union[str, None] = None,
5049
override: Union[InputOverrideConfig, None] = None,
5150
):
51+
from supertokens_python.recipe.jwt import JWTRecipe
52+
5253
super().__init__(recipe_id, app_info)
5354
self.config = validate_and_normalise_user_input(app_info, issuer, override)
5455
jwt_feature = None

supertokens_python/recipe/session/recipe.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
UnauthorisedError,
2929
InvalidClaimsError,
3030
)
31-
from ... import AppInfo
3231
from ...types import MaybeAwaitable
3332

3433
if TYPE_CHECKING:

supertokens_python/recipe/session/session_request_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151
TokenTransferMethod,
5252
get_required_claim_validators,
5353
)
54-
from supertokens_python.supertokens import AppInfo
5554
from supertokens_python.types import MaybeAwaitable
5655
from supertokens_python.utils import (
5756
FRAMEWORKS,
@@ -60,10 +59,11 @@
6059
normalise_http_method,
6160
set_request_in_user_context_if_not_defined,
6261
)
63-
from supertokens_python import Supertokens
62+
from supertokens_python.supertokens import Supertokens
6463

6564
if TYPE_CHECKING:
6665
from supertokens_python.recipe.session.recipe import SessionRecipe
66+
from supertokens_python.supertokens import AppInfo
6767

6868
LEGACY_ID_REFRESH_TOKEN_COOKIE_NAME = "sIdRefreshToken"
6969

supertokens_python/recipe/thirdparty/provider.py

Lines changed: 133 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,15 @@ def __init__(
4141

4242

4343
class Provider:
44-
def __init__(self, id: str): # pylint: disable=redefined-builtin
44+
def __init__(
45+
self, id: str, config: ProviderConfigForClient
46+
): # pylint: disable=redefined-builtin
4547
self.id = id
46-
self.config = ProviderConfigForClientType("temp")
48+
self.config = config
4749

4850
async def get_config_for_client_type( # pylint: disable=no-self-use
4951
self, client_type: Optional[str], user_context: Dict[str, Any]
50-
) -> ProviderConfigForClientType:
52+
) -> ProviderConfigForClient:
5153
_ = client_type
5254
__ = user_context
5355
raise NotImplementedError()
@@ -110,60 +112,6 @@ def to_json(self) -> Dict[str, Any]:
110112
return {k: v for k, v in res.items() if v is not None}
111113

112114

113-
class ProviderConfigForClientType:
114-
def __init__(
115-
self,
116-
client_id: str,
117-
client_secret: Optional[str] = None,
118-
scope: Optional[List[str]] = None,
119-
force_pkce: bool = False,
120-
additional_config: Optional[Dict[str, Any]] = None,
121-
name: Optional[str] = None,
122-
authorization_endpoint: Optional[str] = None,
123-
authorization_endpoint_query_params: Optional[
124-
Dict[str, Union[str, None]]
125-
] = None,
126-
token_endpoint: Optional[str] = None,
127-
token_endpoint_body_params: Optional[Dict[str, Union[str, None]]] = None,
128-
user_info_endpoint: Optional[str] = None,
129-
user_info_endpoint_query_params: Optional[Dict[str, Union[str, None]]] = None,
130-
user_info_endpoint_headers: Optional[Dict[str, Union[str, None]]] = None,
131-
jwks_uri: Optional[str] = None,
132-
oidc_discovery_endpoint: Optional[str] = None,
133-
user_info_map: Optional[UserInfoMap] = None,
134-
require_email: bool = True,
135-
generate_fake_email: Optional[
136-
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
137-
] = None,
138-
validate_id_token_payload: Optional[
139-
Callable[
140-
[Dict[str, Any], ProviderConfigForClientType, Dict[str, Any]],
141-
Awaitable[None],
142-
]
143-
] = None,
144-
):
145-
self.client_id = client_id
146-
self.client_secret = client_secret
147-
self.scope = scope
148-
self.force_pkce = force_pkce
149-
self.additional_config = additional_config
150-
151-
self.name = name
152-
self.authorization_endpoint = authorization_endpoint
153-
self.authorization_endpoint_query_params = authorization_endpoint_query_params
154-
self.token_endpoint = token_endpoint
155-
self.token_endpoint_body_params = token_endpoint_body_params
156-
self.user_info_endpoint = user_info_endpoint
157-
self.user_info_endpoint_query_params = user_info_endpoint_query_params
158-
self.user_info_endpoint_headers = user_info_endpoint_headers
159-
self.jwks_uri = jwks_uri
160-
self.oidc_discovery_endpoint = oidc_discovery_endpoint
161-
self.user_info_map = user_info_map
162-
self.require_email = require_email
163-
self.validate_id_token_payload = validate_id_token_payload
164-
self.generate_fake_email = generate_fake_email
165-
166-
167115
class UserFields:
168116
def __init__(
169117
self,
@@ -201,12 +149,11 @@ def to_json(self) -> Dict[str, Any]:
201149
}
202150

203151

204-
class ProviderConfig:
152+
class CommonProviderConfig:
205153
def __init__(
206154
self,
207155
third_party_id: str,
208156
name: Optional[str] = None,
209-
clients: Optional[List[ProviderClientConfig]] = None,
210157
authorization_endpoint: Optional[str] = None,
211158
authorization_endpoint_query_params: Optional[
212159
Dict[str, Union[str, None]]
@@ -222,7 +169,7 @@ def __init__(
222169
require_email: bool = True,
223170
validate_id_token_payload: Optional[
224171
Callable[
225-
[Dict[str, Any], ProviderConfigForClientType, Dict[str, Any]],
172+
[Dict[str, Any], ProviderConfigForClient, Dict[str, Any]],
226173
Awaitable[None],
227174
]
228175
] = None,
@@ -232,7 +179,6 @@ def __init__(
232179
):
233180
self.third_party_id = third_party_id
234181
self.name = name
235-
self.clients = clients
236182
self.authorization_endpoint = authorization_endpoint
237183
self.authorization_endpoint_query_params = authorization_endpoint_query_params
238184
self.token_endpoint = token_endpoint
@@ -251,9 +197,6 @@ def to_json(self) -> Dict[str, Any]:
251197
res = {
252198
"thirdPartyId": self.third_party_id,
253199
"name": self.name,
254-
"clients": [c.to_json() for c in self.clients]
255-
if self.clients is not None
256-
else [],
257200
"authorizationEndpoint": self.authorization_endpoint,
258201
"authorizationEndpointQueryParams": self.authorization_endpoint_query_params,
259202
"tokenEndpoint": self.token_endpoint,
@@ -272,6 +215,132 @@ def to_json(self) -> Dict[str, Any]:
272215
return {k: v for k, v in res.items() if v is not None}
273216

274217

218+
class ProviderConfigForClient(ProviderClientConfig, CommonProviderConfig):
219+
def __init__(
220+
self,
221+
# ProviderClientConfig:
222+
client_id: str,
223+
client_secret: Optional[str] = None,
224+
client_type: Optional[str] = None,
225+
scope: Optional[List[str]] = None,
226+
force_pkce: bool = False,
227+
additional_config: Optional[Dict[str, Any]] = None,
228+
# CommonProviderConfig:
229+
name: Optional[str] = None,
230+
authorization_endpoint: Optional[str] = None,
231+
authorization_endpoint_query_params: Optional[
232+
Dict[str, Union[str, None]]
233+
] = None,
234+
token_endpoint: Optional[str] = None,
235+
token_endpoint_body_params: Optional[Dict[str, Union[str, None]]] = None,
236+
user_info_endpoint: Optional[str] = None,
237+
user_info_endpoint_query_params: Optional[Dict[str, Union[str, None]]] = None,
238+
user_info_endpoint_headers: Optional[Dict[str, Union[str, None]]] = None,
239+
jwks_uri: Optional[str] = None,
240+
oidc_discovery_endpoint: Optional[str] = None,
241+
user_info_map: Optional[UserInfoMap] = None,
242+
require_email: bool = True,
243+
validate_id_token_payload: Optional[
244+
Callable[
245+
[Dict[str, Any], ProviderConfigForClient, Dict[str, Any]],
246+
Awaitable[None],
247+
]
248+
] = None,
249+
generate_fake_email: Optional[
250+
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
251+
] = None,
252+
):
253+
ProviderClientConfig.__init__(
254+
self,
255+
client_id,
256+
client_secret,
257+
client_type,
258+
scope,
259+
force_pkce,
260+
additional_config,
261+
)
262+
CommonProviderConfig.__init__(
263+
self,
264+
"temp",
265+
name,
266+
authorization_endpoint,
267+
authorization_endpoint_query_params,
268+
token_endpoint,
269+
token_endpoint_body_params,
270+
user_info_endpoint,
271+
user_info_endpoint_query_params,
272+
user_info_endpoint_headers,
273+
jwks_uri,
274+
oidc_discovery_endpoint,
275+
user_info_map,
276+
require_email,
277+
validate_id_token_payload,
278+
generate_fake_email,
279+
)
280+
281+
def to_json(self) -> Dict[str, Any]:
282+
d1 = ProviderClientConfig.to_json(self)
283+
d2 = CommonProviderConfig.to_json(self)
284+
return {**d1, **d2}
285+
286+
287+
class ProviderConfig(CommonProviderConfig):
288+
def __init__(
289+
self,
290+
third_party_id: str,
291+
name: Optional[str] = None,
292+
clients: Optional[List[ProviderClientConfig]] = None,
293+
authorization_endpoint: Optional[str] = None,
294+
authorization_endpoint_query_params: Optional[
295+
Dict[str, Union[str, None]]
296+
] = None,
297+
token_endpoint: Optional[str] = None,
298+
token_endpoint_body_params: Optional[Dict[str, Union[str, None]]] = None,
299+
user_info_endpoint: Optional[str] = None,
300+
user_info_endpoint_query_params: Optional[Dict[str, Union[str, None]]] = None,
301+
user_info_endpoint_headers: Optional[Dict[str, Union[str, None]]] = None,
302+
jwks_uri: Optional[str] = None,
303+
oidc_discovery_endpoint: Optional[str] = None,
304+
user_info_map: Optional[UserInfoMap] = None,
305+
require_email: bool = True,
306+
validate_id_token_payload: Optional[
307+
Callable[
308+
[Dict[str, Any], ProviderConfigForClient, Dict[str, Any]],
309+
Awaitable[None],
310+
]
311+
] = None,
312+
generate_fake_email: Optional[
313+
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
314+
] = None,
315+
):
316+
super().__init__(
317+
third_party_id,
318+
name,
319+
authorization_endpoint,
320+
authorization_endpoint_query_params,
321+
token_endpoint,
322+
token_endpoint_body_params,
323+
user_info_endpoint,
324+
user_info_endpoint_query_params,
325+
user_info_endpoint_headers,
326+
jwks_uri,
327+
oidc_discovery_endpoint,
328+
user_info_map,
329+
require_email,
330+
validate_id_token_payload,
331+
generate_fake_email,
332+
)
333+
self.clients = clients
334+
335+
def to_json(self) -> Dict[str, Any]:
336+
d = CommonProviderConfig.to_json(self)
337+
338+
if self.clients is not None:
339+
d["clients"] = [c.to_json() for c in self.clients]
340+
341+
return d
342+
343+
275344
class ProviderInput:
276345
def __init__(
277346
self,

supertokens_python/recipe/thirdparty/providers/active_directory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from .custom import GenericProvider, NewProvider
1818
from ..provider import (
1919
Provider,
20-
ProviderConfigForClientType,
20+
ProviderConfigForClient,
2121
ProviderInput,
2222
UserFields,
2323
UserInfoMap,
@@ -27,7 +27,7 @@
2727
class ActiveDirectoryImpl(GenericProvider):
2828
async def get_config_for_client_type(
2929
self, client_type: Optional[str], user_context: Dict[str, Any]
30-
) -> ProviderConfigForClientType:
30+
) -> ProviderConfigForClient:
3131
config = await super().get_config_for_client_type(client_type, user_context)
3232
if config.oidc_discovery_endpoint is None:
3333
if (

supertokens_python/recipe/thirdparty/providers/apple.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919
from time import time
2020

2121
from .custom import GenericProvider, NewProvider
22-
from ..provider import Provider, ProviderConfigForClientType, ProviderInput
22+
from ..provider import Provider, ProviderConfigForClient, ProviderInput
2323
from .utils import get_actual_client_id_from_development_client_id
2424

2525

2626
class AppleImpl(GenericProvider):
2727
async def get_config_for_client_type(
2828
self, client_type: Optional[str], user_context: Dict[str, Any]
29-
) -> ProviderConfigForClientType:
29+
) -> ProviderConfigForClient:
3030
config = await super().get_config_for_client_type(client_type, user_context)
3131

3232
if config.scope is None:
@@ -38,7 +38,7 @@ async def get_config_for_client_type(
3838
return config
3939

4040
async def _get_client_secret( # pylint: disable=no-self-use
41-
self, config: ProviderConfigForClientType
41+
self, config: ProviderConfigForClient
4242
) -> str:
4343
if (
4444
config.additional_config is None

supertokens_python/recipe/thirdparty/providers/boxy_saml.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from .custom import GenericProvider, NewProvider
1818
from ..provider import (
1919
Provider,
20-
ProviderConfigForClientType,
20+
ProviderConfigForClient,
2121
ProviderInput,
2222
UserFields,
2323
UserInfoMap,
@@ -27,7 +27,7 @@
2727
class BoxySAMLImpl(GenericProvider):
2828
async def get_config_for_client_type(
2929
self, client_type: Optional[str], user_context: Dict[str, Any]
30-
) -> ProviderConfigForClientType:
30+
) -> ProviderConfigForClient:
3131
config = await super().get_config_for_client_type(client_type, user_context)
3232
if (
3333
config.additional_config is None

supertokens_python/recipe/thirdparty/providers/config_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from ..provider import (
1919
ProviderConfig,
20-
ProviderConfigForClientType,
20+
ProviderConfigForClient,
2121
ProviderInput,
2222
Provider,
2323
UserFields,
@@ -216,8 +216,8 @@ async def get_oidc_discovery_info(issuer: str):
216216

217217

218218
async def discover_oidc_endpoints(
219-
config: ProviderConfigForClientType,
220-
) -> ProviderConfigForClientType:
219+
config: ProviderConfigForClient,
220+
) -> ProviderConfigForClient:
221221
if config.oidc_discovery_endpoint is None:
222222
return config
223223

0 commit comments

Comments
 (0)