Skip to content

Commit d679e31

Browse files
added success test case
changed order of positional arguments
1 parent 3bddf98 commit d679e31

File tree

5 files changed

+88
-86
lines changed

5 files changed

+88
-86
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88

99
## [unreleased]
1010

11+
## [0.16.4] - 2023-10-05
12+
1113
- Add `validate_access_token` function to providers
1214
- This can be used to verify the access token received from providers.
1315
- Implemented `validate_access_token` for the Github provider.

supertokens_python/recipe/thirdparty/provider.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -173,15 +173,15 @@ def __init__(
173173
Awaitable[None],
174174
]
175175
] = None,
176+
generate_fake_email: Optional[
177+
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
178+
] = None,
176179
validate_access_token: Optional[
177180
Callable[
178181
[str, ProviderConfigForClient, Dict[str, Any]],
179182
Awaitable[None],
180183
]
181184
] = None,
182-
generate_fake_email: Optional[
183-
Callable[[str, str, Dict[str, Any]], Awaitable[str]]
184-
] = None,
185185
):
186186
self.third_party_id = third_party_id
187187
self.name = name
@@ -289,8 +289,8 @@ def __init__(
289289
user_info_map,
290290
require_email,
291291
validate_id_token_payload,
292-
validate_access_token,
293292
generate_fake_email,
293+
validate_access_token,
294294
)
295295

296296
def to_json(self) -> Dict[str, Any]:
@@ -349,8 +349,8 @@ def __init__(
349349
user_info_map,
350350
require_email,
351351
validate_id_token_payload,
352-
validate_access_token,
353352
generate_fake_email,
353+
validate_access_token,
354354
)
355355
self.clients = clients
356356

supertokens_python/recipe/thirdparty/providers/custom.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,12 @@ async def get_user_info(
403403
user_context,
404404
)
405405

406-
if access_token is not None and self.config.token_endpoint is not None:
406+
if self.config.validate_access_token is not None and access_token is not None:
407+
await self.config.validate_access_token(
408+
access_token, self.config, user_context
409+
)
410+
411+
if access_token is not None and self.config.user_info_endpoint is not None:
407412
headers: Dict[str, str] = {"Authorization": f"Bearer {access_token}"}
408413
query_params: Dict[str, str] = {}
409414

@@ -422,11 +427,6 @@ async def get_user_info(
422427
self.config.user_info_endpoint, query_params, headers
423428
)
424429

425-
if self.config.validate_access_token is not None and access_token is not None:
426-
await self.config.validate_access_token(
427-
access_token, self.config, user_context
428-
)
429-
430430
user_info_result = get_supertokens_user_info_result_from_raw_user_info(
431431
self.config, raw_user_info_from_provider
432432
)

supertokens_python/recipe/thirdparty/providers/github.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
from __future__ import annotations
1515

1616
import base64
17-
import json
1817
from typing import Any, Dict, List, Optional
1918

20-
import requests
21-
22-
from supertokens_python.recipe.thirdparty.providers.utils import do_get_request
19+
from supertokens_python.recipe.thirdparty.providers.utils import (
20+
do_get_request,
21+
do_post_request,
22+
)
2323
from supertokens_python.recipe.thirdparty.types import UserInfo, UserInfoEmail
2424

2525
from .custom import GenericProvider, NewProvider
@@ -95,14 +95,11 @@ async def validate_access_token(
9595
"Authorization": f"Basic {basic_auth_token}",
9696
"Content-Type": "application/json",
9797
}
98-
payload = json.dumps({"access_token": access_token})
99-
100-
resp = requests.post(url, headers=headers, data=payload)
10198

102-
if resp.status_code != 200:
99+
try:
100+
body = await do_post_request(url, {"access_token": access_token}, headers)
101+
except Exception:
103102
raise ValueError("Invalid access token")
104103

105-
body = resp.json()
106-
107-
if "app" not in body or body["app"]["client_id"] != config.client_id:
104+
if "app" not in body or body["app"].get("client_id") != config.client_id:
108105
raise ValueError("Access token does not belong to your application")

tests/thirdparty/test_thirdparty.py

Lines changed: 67 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import respx
77
from fastapi import FastAPI
88
from pytest import fixture, mark
9+
from pytest_mock import MockerFixture
910
from starlette.testclient import TestClient
1011

1112
from supertokens_python import init
@@ -106,18 +107,6 @@ async def exchange_auth_code_for_valid_oauth_tokens( # pylint: disable=unused-a
106107
}
107108

108109

109-
async def get_user_info( # pylint: disable=unused-argument
110-
oauth_tokens: Dict[str, Any],
111-
user_context: Dict[str, Any],
112-
) -> UserInfo:
113-
time = str(datetime.datetime.now())
114-
return UserInfo(
115-
"" + time,
116-
UserInfoEmail(f"johndoeprovidertest+{time}@supertokens.com", True),
117-
RawUserInfoFromProvider({}, {}),
118-
)
119-
120-
121110
async def exchange_auth_code_for_invalid_oauth_tokens( # pylint: disable=unused-argument
122111
redirect_uri_info: RedirectUriInfo,
123112
user_context: Dict[str, Any],
@@ -139,7 +128,6 @@ def get_custom_valid_token_provider(provider: Provider) -> Provider:
139128
provider.exchange_auth_code_for_oauth_tokens = (
140129
exchange_auth_code_for_valid_oauth_tokens
141130
)
142-
provider.get_user_info = get_user_info
143131
return provider
144132

145133

@@ -153,7 +141,9 @@ async def invalid_access_token( # pylint: disable=unused-argument
153141

154142

155143
async def valid_access_token( # pylint: disable=unused-argument
156-
access_token: str, config: ProviderConfig, user_context: Optional[Dict[str, Any]]
144+
access_token: str,
145+
config: ProviderConfigForClient,
146+
user_context: Optional[Dict[str, Any]],
157147
):
158148
if access_token == "accesstoken":
159149
return
@@ -210,53 +200,66 @@ async def test_signinup_when_validate_access_token_throws(fastapi_client: TestCl
210200
assert res.status_code == 500
211201

212202

213-
# async def test_signinup_works_when_validate_access_token_does_not_throw(fastapi_client: TestClient):
214-
# st_init_args = {
215-
# **st_init_common_args,
216-
# "recipe_list": [
217-
# session.init(),
218-
# thirdpartyemailpassword.init(
219-
# providers=[
220-
# ProviderInput(
221-
# config=ProviderConfig(
222-
# third_party_id="custom",
223-
# clients=[
224-
# ProviderClientConfig(
225-
# client_id="test",
226-
# client_secret="test-secret",
227-
# scope=["profile", "email"],
228-
# ),
229-
# ],
230-
# authorization_endpoint="https://example.com/oauth/authorize",
231-
# validate_access_token=valid_access_token,
232-
# authorization_endpoint_query_params={
233-
# "response_type": "token", # Changing an existing parameter
234-
# "response_mode": "form", # Adding a new parameter
235-
# "scope": None, # Removing a parameter
236-
# },
237-
# token_endpoint="https://example.com/oauth/token",
238-
# ),
239-
# override=get_custom_valid_token_provider
240-
# )
241-
# ]
242-
# ),
243-
# ],
244-
# }
245-
#
246-
# init(**st_init_args) # type: ignore
247-
# start_st()
248-
#
249-
# res = fastapi_client.post(
250-
# "/auth/signinup",
251-
# json={
252-
# "thirdPartyId": "custom",
253-
# "redirectURIInfo": {
254-
# "redirectURIOnProviderDashboard": "http://127.0.0.1/callback",
255-
# "redirectURIQueryParams": {
256-
# "code": "abcdefghj",
257-
# },
258-
# },
259-
# }
260-
# )
261-
# assert res.status_code == 200
262-
# assert res.json()["status"] == "OK"
203+
async def test_signinup_works_when_validate_access_token_does_not_throw(
204+
fastapi_client: TestClient, mocker: MockerFixture
205+
):
206+
time = str(datetime.datetime.now())
207+
mocker.patch(
208+
"supertokens_python.recipe.thirdparty.providers.custom.get_supertokens_user_info_result_from_raw_user_info",
209+
return_value=UserInfo(
210+
"" + time,
211+
UserInfoEmail(f"johndoeprovidertest+{time}@supertokens.com", True),
212+
RawUserInfoFromProvider({}, {}),
213+
),
214+
)
215+
216+
st_init_args = {
217+
**st_init_common_args,
218+
"recipe_list": [
219+
session.init(),
220+
thirdpartyemailpassword.init(
221+
providers=[
222+
ProviderInput(
223+
config=ProviderConfig(
224+
third_party_id="custom",
225+
clients=[
226+
ProviderClientConfig(
227+
client_id="test",
228+
client_secret="test-secret",
229+
scope=["profile", "email"],
230+
),
231+
],
232+
authorization_endpoint="https://example.com/oauth/authorize",
233+
validate_access_token=valid_access_token,
234+
authorization_endpoint_query_params={
235+
"response_type": "token", # Changing an existing parameter
236+
"response_mode": "form", # Adding a new parameter
237+
"scope": None, # Removing a parameter
238+
},
239+
token_endpoint="https://example.com/oauth/token",
240+
),
241+
override=get_custom_valid_token_provider,
242+
)
243+
]
244+
),
245+
],
246+
}
247+
248+
init(**st_init_args) # type: ignore
249+
start_st()
250+
251+
res = fastapi_client.post(
252+
"/auth/signinup",
253+
json={
254+
"thirdPartyId": "custom",
255+
"redirectURIInfo": {
256+
"redirectURIOnProviderDashboard": "http://127.0.0.1/callback",
257+
"redirectURIQueryParams": {
258+
"code": "abcdefghj",
259+
},
260+
},
261+
},
262+
)
263+
264+
assert res.status_code == 200
265+
assert res.json()["status"] == "OK"

0 commit comments

Comments
 (0)