Skip to content

Commit 6e71e91

Browse files
Merge pull request #453 from supertokens/feat/access-token-validation
feat: Add `validate_access_token` function to providers
2 parents eae5482 + 69b5d7b commit 6e71e91

File tree

11 files changed

+283
-30
lines changed

11 files changed

+283
-30
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88

99
## [unreleased]
1010

11+
## [0.16.4] - 2023-10-05
12+
13+
- Add `validate_access_token` function to providers
14+
- This can be used to verify the access token received from providers.
15+
- Implemented `validate_access_token` for the Github provider.
16+
1117
## [0.16.3] - 2023-09-28
1218

1319
- Add Twitter provider for thirdparty login

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070

7171
setup(
7272
name="supertokens_python",
73-
version="0.16.3",
73+
version="0.16.4",
7474
author="SuperTokens",
7575
license="Apache 2.0",
7676
author_email="[email protected]",

supertokens_python/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import annotations
1515

1616
SUPPORTED_CDI_VERSIONS = ["3.0"]
17-
VERSION = "0.16.3"
17+
VERSION = "0.16.4"
1818
TELEMETRY = "/telemetry"
1919
USER_COUNT = "/users/count"
2020
USER_DELETE = "/user/remove"

supertokens_python/recipe/multitenancy/recipe_implementation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def parse_tenant_config(tenant: Dict[str, Any]) -> TenantConfigResponse:
105105
require_email=p.get("requireEmail", True),
106106
validate_id_token_payload=None,
107107
generate_fake_email=None,
108+
validate_access_token=None,
108109
)
109110
)
110111

supertokens_python/recipe/thirdparty/provider.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,12 @@ def __init__(
176176
generate_fake_email: Optional[
177177
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
178178
] = None,
179+
validate_access_token: Optional[
180+
Callable[
181+
[str, ProviderConfigForClient, Dict[str, Any]],
182+
Awaitable[None],
183+
]
184+
] = None,
179185
):
180186
self.third_party_id = third_party_id
181187
self.name = name
@@ -192,6 +198,7 @@ def __init__(
192198
self.require_email = require_email
193199
self.validate_id_token_payload = validate_id_token_payload
194200
self.generate_fake_email = generate_fake_email
201+
self.validate_access_token = validate_access_token
195202

196203
def to_json(self) -> Dict[str, Any]:
197204
res = {
@@ -250,6 +257,12 @@ def __init__(
250257
generate_fake_email: Optional[
251258
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
252259
] = None,
260+
validate_access_token: Optional[
261+
Callable[
262+
[str, ProviderConfigForClient, Dict[str, Any]],
263+
Awaitable[None],
264+
]
265+
] = None,
253266
):
254267
ProviderClientConfig.__init__(
255268
self,
@@ -277,6 +290,7 @@ def __init__(
277290
require_email,
278291
validate_id_token_payload,
279292
generate_fake_email,
293+
validate_access_token,
280294
)
281295

282296
def to_json(self) -> Dict[str, Any]:
@@ -313,6 +327,12 @@ def __init__(
313327
generate_fake_email: Optional[
314328
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
315329
] = None,
330+
validate_access_token: Optional[
331+
Callable[
332+
[str, ProviderConfigForClient, Dict[str, Any]],
333+
Awaitable[None],
334+
]
335+
] = None,
316336
):
317337
super().__init__(
318338
third_party_id,
@@ -330,6 +350,7 @@ def __init__(
330350
require_email,
331351
validate_id_token_payload,
332352
generate_fake_email,
353+
validate_access_token,
333354
)
334355
self.clients = clients
335356

supertokens_python/recipe/thirdparty/providers/config_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def merge_config(
8787
user_info_map=config_from_static.user_info_map,
8888
generate_fake_email=config_from_static.generate_fake_email,
8989
validate_id_token_payload=config_from_static.validate_id_token_payload,
90+
validate_access_token=config_from_static.validate_access_token,
9091
)
9192

9293
if result.user_info_map is None:

supertokens_python/recipe/thirdparty/providers/custom.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def get_provider_config_for_client(
6060
require_email=config.require_email,
6161
validate_id_token_payload=config.validate_id_token_payload,
6262
generate_fake_email=config.generate_fake_email,
63+
validate_access_token=config.validate_access_token,
6364
)
6465

6566

@@ -375,7 +376,8 @@ async def exchange_auth_code_for_oauth_tokens(
375376
access_token_params["redirect_uri"] = DEV_OAUTH_REDIRECT_URL
376377
# Transformation needed for dev keys END
377378

378-
return await do_post_request(token_api_url, access_token_params)
379+
_, body = await do_post_request(token_api_url, access_token_params)
380+
return body
379381

380382
async def get_user_info(
381383
self, oauth_tokens: Dict[str, Any], user_context: Dict[str, Any]
@@ -402,25 +404,29 @@ async def get_user_info(
402404
user_context,
403405
)
404406

405-
if access_token is not None and self.config.token_endpoint is not None:
407+
if self.config.validate_access_token is not None and access_token is not None:
408+
await self.config.validate_access_token(
409+
access_token, self.config, user_context
410+
)
411+
412+
if access_token is not None and self.config.user_info_endpoint is not None:
406413
headers: Dict[str, str] = {"Authorization": f"Bearer {access_token}"}
407414
query_params: Dict[str, str] = {}
408415

409-
if self.config.user_info_endpoint is not None:
410-
if self.config.user_info_endpoint_headers is not None:
411-
headers = merge_into_dict(
412-
self.config.user_info_endpoint_headers, headers
413-
)
414-
415-
if self.config.user_info_endpoint_query_params is not None:
416-
query_params = merge_into_dict(
417-
self.config.user_info_endpoint_query_params, query_params
418-
)
416+
if self.config.user_info_endpoint_headers is not None:
417+
headers = merge_into_dict(
418+
self.config.user_info_endpoint_headers, headers
419+
)
419420

420-
raw_user_info_from_provider.from_user_info_api = await do_get_request(
421-
self.config.user_info_endpoint, query_params, headers
421+
if self.config.user_info_endpoint_query_params is not None:
422+
query_params = merge_into_dict(
423+
self.config.user_info_endpoint_query_params, query_params
422424
)
423425

426+
raw_user_info_from_provider.from_user_info_api = await do_get_request(
427+
self.config.user_info_endpoint, query_params, headers
428+
)
429+
424430
user_info_result = get_supertokens_user_info_result_from_raw_user_info(
425431
self.config, raw_user_info_from_provider
426432
)

supertokens_python/recipe/thirdparty/providers/github.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,14 @@
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
1414
from __future__ import annotations
15+
16+
import base64
1517
from typing import Any, Dict, List, Optional
1618

17-
from supertokens_python.recipe.thirdparty.providers.utils import do_get_request
19+
from supertokens_python.recipe.thirdparty.providers.utils import (
20+
do_get_request,
21+
do_post_request,
22+
)
1823
from supertokens_python.recipe.thirdparty.types import UserInfo, UserInfoEmail
1924

2025
from .custom import GenericProvider, NewProvider
@@ -71,4 +76,29 @@ def Github(input: ProviderInput) -> Provider: # pylint: disable=redefined-built
7176
if input.config.token_endpoint is None:
7277
input.config.token_endpoint = "https://github.com/login/oauth/access_token"
7378

79+
if input.config.validate_access_token is None:
80+
input.config.validate_access_token = validate_access_token
81+
7482
return NewProvider(input, GithubImpl)
83+
84+
85+
async def validate_access_token(
86+
access_token: str, config: ProviderConfigForClient, _: Dict[str, Any]
87+
):
88+
client_secret = "" if config.client_secret is None else config.client_secret
89+
basic_auth_token = base64.b64encode(
90+
f"{config.client_id}:{client_secret}".encode()
91+
).decode()
92+
93+
url = f"https://api.github.com/applications/{config.client_id}/token"
94+
headers = {
95+
"Authorization": f"Basic {basic_auth_token}",
96+
"Content-Type": "application/json",
97+
}
98+
99+
status, body = await do_post_request(url, {"access_token": access_token}, headers)
100+
if status != 200:
101+
raise ValueError("Invalid access token")
102+
103+
if "app" not in body or body["app"].get("client_id") != config.client_id:
104+
raise ValueError("Access token does not belong to your application")

supertokens_python/recipe/thirdparty/providers/twitter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,12 @@ async def exchange_auth_code_for_oauth_tokens(
8484

8585
assert self.config.token_endpoint is not None
8686

87-
return await do_post_request(
87+
_, body = await do_post_request(
8888
self.config.token_endpoint,
8989
body_params=twitter_oauth_tokens_params,
9090
headers={"Authorization": f"Basic {auth_token}"},
9191
)
92+
return body
9293

9394

9495
def Twitter(input: ProviderInput) -> Provider: # pylint: disable=redefined-builtin

supertokens_python/recipe/thirdparty/providers/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, Optional
1+
from typing import Any, Dict, Optional, Tuple
22

33
from httpx import AsyncClient
44

@@ -48,7 +48,7 @@ async def do_post_request(
4848
url: str,
4949
body_params: Optional[Dict[str, str]] = None,
5050
headers: Optional[Dict[str, str]] = None,
51-
) -> Dict[str, Any]:
51+
) -> Tuple[int, Dict[str, Any]]:
5252
if body_params is None:
5353
body_params = {}
5454
if headers is None:
@@ -62,4 +62,4 @@ async def do_post_request(
6262
log_debug_message(
6363
"Received response with status %s and body %s", res.status_code, res.text
6464
)
65-
return res.json()
65+
return res.status_code, res.json()

0 commit comments

Comments
 (0)