Skip to content

fix: Clean up thirdparty recipe and make it consistent with other python recipes #395

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion supertokens_python/recipe/openid/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from typing import TYPE_CHECKING, List, Union, Optional, Any, Dict

from supertokens_python.querier import Querier
from supertokens_python.recipe.jwt import JWTRecipe

from .api.implementation import APIImplementation
from .api.open_id_discovery_configuration_get import open_id_discovery_configuration_get
Expand Down Expand Up @@ -49,6 +48,8 @@ def __init__(
issuer: Union[str, None] = None,
override: Union[InputOverrideConfig, None] = None,
):
from supertokens_python.recipe.jwt import JWTRecipe

super().__init__(recipe_id, app_info)
self.config = validate_and_normalise_user_input(app_info, issuer, override)
jwt_feature = None
Expand Down
1 change: 0 additions & 1 deletion supertokens_python/recipe/session/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
UnauthorisedError,
InvalidClaimsError,
)
from ... import AppInfo
from ...types import MaybeAwaitable

if TYPE_CHECKING:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
TokenTransferMethod,
get_required_claim_validators,
)
from supertokens_python.supertokens import AppInfo
from supertokens_python.types import MaybeAwaitable
from supertokens_python.utils import (
FRAMEWORKS,
Expand All @@ -60,10 +59,11 @@
normalise_http_method,
set_request_in_user_context_if_not_defined,
)
from supertokens_python import Supertokens
from supertokens_python.supertokens import Supertokens

if TYPE_CHECKING:
from supertokens_python.recipe.session.recipe import SessionRecipe
from supertokens_python.supertokens import AppInfo

LEGACY_ID_REFRESH_TOKEN_COOKIE_NAME = "sIdRefreshToken"

Expand Down
197 changes: 133 additions & 64 deletions supertokens_python/recipe/thirdparty/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,15 @@ def __init__(


class Provider:
def __init__(self, id: str): # pylint: disable=redefined-builtin
def __init__(
self, id: str, config: ProviderConfigForClient
): # pylint: disable=redefined-builtin
self.id = id
self.config = ProviderConfigForClientType("temp")
self.config = config

async def get_config_for_client_type( # pylint: disable=no-self-use
self, client_type: Optional[str], user_context: Dict[str, Any]
) -> ProviderConfigForClientType:
) -> ProviderConfigForClient:
_ = client_type
__ = user_context
raise NotImplementedError()
Expand Down Expand Up @@ -110,60 +112,6 @@ def to_json(self) -> Dict[str, Any]:
return {k: v for k, v in res.items() if v is not None}


class ProviderConfigForClientType:
def __init__(
self,
client_id: str,
client_secret: Optional[str] = None,
scope: Optional[List[str]] = None,
force_pkce: bool = False,
additional_config: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
authorization_endpoint: Optional[str] = None,
authorization_endpoint_query_params: Optional[
Dict[str, Union[str, None]]
] = None,
token_endpoint: Optional[str] = None,
token_endpoint_body_params: Optional[Dict[str, Union[str, None]]] = None,
user_info_endpoint: Optional[str] = None,
user_info_endpoint_query_params: Optional[Dict[str, Union[str, None]]] = None,
user_info_endpoint_headers: Optional[Dict[str, Union[str, None]]] = None,
jwks_uri: Optional[str] = None,
oidc_discovery_endpoint: Optional[str] = None,
user_info_map: Optional[UserInfoMap] = None,
require_email: bool = True,
generate_fake_email: Optional[
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
] = None,
validate_id_token_payload: Optional[
Callable[
[Dict[str, Any], ProviderConfigForClientType, Dict[str, Any]],
Awaitable[None],
]
] = None,
):
self.client_id = client_id
self.client_secret = client_secret
self.scope = scope
self.force_pkce = force_pkce
self.additional_config = additional_config

self.name = name
self.authorization_endpoint = authorization_endpoint
self.authorization_endpoint_query_params = authorization_endpoint_query_params
self.token_endpoint = token_endpoint
self.token_endpoint_body_params = token_endpoint_body_params
self.user_info_endpoint = user_info_endpoint
self.user_info_endpoint_query_params = user_info_endpoint_query_params
self.user_info_endpoint_headers = user_info_endpoint_headers
self.jwks_uri = jwks_uri
self.oidc_discovery_endpoint = oidc_discovery_endpoint
self.user_info_map = user_info_map
self.require_email = require_email
self.validate_id_token_payload = validate_id_token_payload
self.generate_fake_email = generate_fake_email


class UserFields:
def __init__(
self,
Expand Down Expand Up @@ -201,12 +149,11 @@ def to_json(self) -> Dict[str, Any]:
}


class ProviderConfig:
class CommonProviderConfig:
def __init__(
self,
third_party_id: str,
name: Optional[str] = None,
clients: Optional[List[ProviderClientConfig]] = None,
authorization_endpoint: Optional[str] = None,
authorization_endpoint_query_params: Optional[
Dict[str, Union[str, None]]
Expand All @@ -222,7 +169,7 @@ def __init__(
require_email: bool = True,
validate_id_token_payload: Optional[
Callable[
[Dict[str, Any], ProviderConfigForClientType, Dict[str, Any]],
[Dict[str, Any], ProviderConfigForClient, Dict[str, Any]],
Awaitable[None],
]
] = None,
Expand All @@ -232,7 +179,6 @@ def __init__(
):
self.third_party_id = third_party_id
self.name = name
self.clients = clients
self.authorization_endpoint = authorization_endpoint
self.authorization_endpoint_query_params = authorization_endpoint_query_params
self.token_endpoint = token_endpoint
Expand All @@ -251,9 +197,6 @@ def to_json(self) -> Dict[str, Any]:
res = {
"thirdPartyId": self.third_party_id,
"name": self.name,
"clients": [c.to_json() for c in self.clients]
if self.clients is not None
else [],
"authorizationEndpoint": self.authorization_endpoint,
"authorizationEndpointQueryParams": self.authorization_endpoint_query_params,
"tokenEndpoint": self.token_endpoint,
Expand All @@ -272,6 +215,132 @@ def to_json(self) -> Dict[str, Any]:
return {k: v for k, v in res.items() if v is not None}


class ProviderConfigForClient(ProviderClientConfig, CommonProviderConfig):
def __init__(
self,
# ProviderClientConfig:
client_id: str,
client_secret: Optional[str] = None,
client_type: Optional[str] = None,
scope: Optional[List[str]] = None,
force_pkce: bool = False,
additional_config: Optional[Dict[str, Any]] = None,
# CommonProviderConfig:
name: Optional[str] = None,
authorization_endpoint: Optional[str] = None,
authorization_endpoint_query_params: Optional[
Dict[str, Union[str, None]]
] = None,
token_endpoint: Optional[str] = None,
token_endpoint_body_params: Optional[Dict[str, Union[str, None]]] = None,
user_info_endpoint: Optional[str] = None,
user_info_endpoint_query_params: Optional[Dict[str, Union[str, None]]] = None,
user_info_endpoint_headers: Optional[Dict[str, Union[str, None]]] = None,
jwks_uri: Optional[str] = None,
oidc_discovery_endpoint: Optional[str] = None,
user_info_map: Optional[UserInfoMap] = None,
require_email: bool = True,
validate_id_token_payload: Optional[
Callable[
[Dict[str, Any], ProviderConfigForClient, Dict[str, Any]],
Awaitable[None],
]
] = None,
generate_fake_email: Optional[
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
] = None,
):
ProviderClientConfig.__init__(
self,
client_id,
client_secret,
client_type,
scope,
force_pkce,
additional_config,
)
CommonProviderConfig.__init__(
self,
"temp",
name,
authorization_endpoint,
authorization_endpoint_query_params,
token_endpoint,
token_endpoint_body_params,
user_info_endpoint,
user_info_endpoint_query_params,
user_info_endpoint_headers,
jwks_uri,
oidc_discovery_endpoint,
user_info_map,
require_email,
validate_id_token_payload,
generate_fake_email,
)

def to_json(self) -> Dict[str, Any]:
d1 = ProviderClientConfig.to_json(self)
d2 = CommonProviderConfig.to_json(self)
return {**d1, **d2}


class ProviderConfig(CommonProviderConfig):
def __init__(
self,
third_party_id: str,
name: Optional[str] = None,
clients: Optional[List[ProviderClientConfig]] = None,
authorization_endpoint: Optional[str] = None,
authorization_endpoint_query_params: Optional[
Dict[str, Union[str, None]]
] = None,
token_endpoint: Optional[str] = None,
token_endpoint_body_params: Optional[Dict[str, Union[str, None]]] = None,
user_info_endpoint: Optional[str] = None,
user_info_endpoint_query_params: Optional[Dict[str, Union[str, None]]] = None,
user_info_endpoint_headers: Optional[Dict[str, Union[str, None]]] = None,
jwks_uri: Optional[str] = None,
oidc_discovery_endpoint: Optional[str] = None,
user_info_map: Optional[UserInfoMap] = None,
require_email: bool = True,
validate_id_token_payload: Optional[
Callable[
[Dict[str, Any], ProviderConfigForClient, Dict[str, Any]],
Awaitable[None],
]
] = None,
generate_fake_email: Optional[
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
] = None,
):
super().__init__(
third_party_id,
name,
authorization_endpoint,
authorization_endpoint_query_params,
token_endpoint,
token_endpoint_body_params,
user_info_endpoint,
user_info_endpoint_query_params,
user_info_endpoint_headers,
jwks_uri,
oidc_discovery_endpoint,
user_info_map,
require_email,
validate_id_token_payload,
generate_fake_email,
)
self.clients = clients

def to_json(self) -> Dict[str, Any]:
d = CommonProviderConfig.to_json(self)

if self.clients is not None:
d["clients"] = [c.to_json() for c in self.clients]

return d


class ProviderInput:
def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .custom import GenericProvider, NewProvider
from ..provider import (
Provider,
ProviderConfigForClientType,
ProviderConfigForClient,
ProviderInput,
UserFields,
UserInfoMap,
Expand All @@ -27,7 +27,7 @@
class ActiveDirectoryImpl(GenericProvider):
async def get_config_for_client_type(
self, client_type: Optional[str], user_context: Dict[str, Any]
) -> ProviderConfigForClientType:
) -> ProviderConfigForClient:
config = await super().get_config_for_client_type(client_type, user_context)
if config.oidc_discovery_endpoint is None:
if (
Expand Down
6 changes: 3 additions & 3 deletions supertokens_python/recipe/thirdparty/providers/apple.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
from time import time

from .custom import GenericProvider, NewProvider
from ..provider import Provider, ProviderConfigForClientType, ProviderInput
from ..provider import Provider, ProviderConfigForClient, ProviderInput
from .utils import get_actual_client_id_from_development_client_id


class AppleImpl(GenericProvider):
async def get_config_for_client_type(
self, client_type: Optional[str], user_context: Dict[str, Any]
) -> ProviderConfigForClientType:
) -> ProviderConfigForClient:
config = await super().get_config_for_client_type(client_type, user_context)

if config.scope is None:
Expand All @@ -38,7 +38,7 @@ async def get_config_for_client_type(
return config

async def _get_client_secret( # pylint: disable=no-self-use
self, config: ProviderConfigForClientType
self, config: ProviderConfigForClient
) -> str:
if (
config.additional_config is None
Expand Down
4 changes: 2 additions & 2 deletions supertokens_python/recipe/thirdparty/providers/boxy_saml.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .custom import GenericProvider, NewProvider
from ..provider import (
Provider,
ProviderConfigForClientType,
ProviderConfigForClient,
ProviderInput,
UserFields,
UserInfoMap,
Expand All @@ -27,7 +27,7 @@
class BoxySAMLImpl(GenericProvider):
async def get_config_for_client_type(
self, client_type: Optional[str], user_context: Dict[str, Any]
) -> ProviderConfigForClientType:
) -> ProviderConfigForClient:
config = await super().get_config_for_client_type(client_type, user_context)
if (
config.additional_config is None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from ..provider import (
ProviderConfig,
ProviderConfigForClientType,
ProviderConfigForClient,
ProviderInput,
Provider,
UserFields,
Expand Down Expand Up @@ -216,8 +216,8 @@ async def get_oidc_discovery_info(issuer: str):


async def discover_oidc_endpoints(
config: ProviderConfigForClientType,
) -> ProviderConfigForClientType:
config: ProviderConfigForClient,
) -> ProviderConfigForClient:
if config.oidc_discovery_endpoint is None:
return config

Expand Down
Loading