Skip to content

Commit 51378bc

Browse files
fix: Added validate_access_token for github provider
1 parent 3429fef commit 51378bc

File tree

2 files changed

+298
-0
lines changed

2 files changed

+298
-0
lines changed

supertokens_python/recipe/thirdparty/providers/github.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,13 @@
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
1414
from __future__ import annotations
15+
16+
import base64
17+
import json
1518
from typing import Any, Dict, List, Optional
1619

20+
import requests
21+
1722
from supertokens_python.recipe.thirdparty.providers.utils import do_get_request
1823
from supertokens_python.recipe.thirdparty.types import UserInfo, UserInfoEmail
1924

@@ -71,4 +76,35 @@ def Github(input: ProviderInput) -> Provider: # pylint: disable=redefined-built
7176
if input.config.token_endpoint is None:
7277
input.config.token_endpoint = "https://github.com/login/oauth/access_token"
7378

79+
if input.config.validate_access_token is None:
80+
input.config.validate_access_token = validate_access_token
81+
7482
return NewProvider(input, GithubImpl)
83+
84+
85+
async def validate_access_token(
86+
access_token: str, config: ProviderConfigForClient, _: Dict[str, Any]
87+
):
88+
client_secret = "" if config.client_secret is None else config.client_secret
89+
basic_auth_token = base64.b64encode(
90+
f"{config.client_id}:{client_secret}".encode()
91+
).decode()
92+
93+
# POST request to get applications response
94+
url = f"https://api.github.com/applications/{config.client_id}/token"
95+
headers = {
96+
"Authorization": f"Basic {basic_auth_token}",
97+
"Content-Type": "application/json",
98+
}
99+
payload = json.dumps({"access_token": access_token})
100+
101+
resp = requests.post(url, headers=headers, data=payload)
102+
103+
# Error handling and validation
104+
if resp.status_code != 200:
105+
raise ValueError("Invalid access token")
106+
107+
body = resp.json()
108+
109+
if "app" not in body or body["app"]["client_id"] != config.client_id:
110+
raise ValueError("Access token does not belong to your application")

tests/thirdparty/test_providers.py

Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
import datetime
2+
import json
3+
from base64 import b64encode
4+
from typing import Dict, Any, Optional
5+
6+
import respx
7+
from fastapi import FastAPI
8+
from pytest import fixture, mark
9+
from starlette.testclient import TestClient
10+
11+
from supertokens_python import init
12+
from supertokens_python.framework.fastapi import get_middleware
13+
from supertokens_python.recipe import session, thirdparty
14+
from supertokens_python.recipe import thirdpartyemailpassword
15+
from supertokens_python.recipe.thirdparty.provider import (
16+
ProviderClientConfig,
17+
ProviderConfig,
18+
ProviderInput,
19+
Provider,
20+
RedirectUriInfo,
21+
ProviderConfigForClient,
22+
)
23+
from supertokens_python.recipe.thirdparty.types import (
24+
UserInfo,
25+
UserInfoEmail,
26+
RawUserInfoFromProvider,
27+
)
28+
from tests.utils import (
29+
setup_function,
30+
teardown_function,
31+
start_st,
32+
st_init_common_args,
33+
)
34+
35+
_ = setup_function # type:ignore
36+
_ = teardown_function # type:ignore
37+
_ = start_st # type:ignore
38+
39+
pytestmark = mark.asyncio
40+
41+
respx_mock = respx.MockRouter
42+
43+
44+
@fixture(scope="function")
45+
async def fastapi_client():
46+
app = FastAPI()
47+
app.add_middleware(get_middleware())
48+
49+
return TestClient(app, raise_server_exceptions=False)
50+
51+
52+
async def test_thirdpary_parsing_works(fastapi_client: TestClient):
53+
st_init_args = {
54+
**st_init_common_args,
55+
"recipe_list": [
56+
session.init(),
57+
thirdparty.init(
58+
sign_in_and_up_feature=thirdparty.SignInAndUpFeature(
59+
providers=[
60+
thirdparty.ProviderInput(
61+
config=thirdparty.ProviderConfig(
62+
third_party_id="apple",
63+
clients=[
64+
thirdparty.ProviderClientConfig(
65+
client_id="4398792-io.supertokens.example.service",
66+
additional_config={
67+
"keyId": "7M48Y4RYDL",
68+
"teamId": "YWQCXGJRJL",
69+
"privateKey": "-----BEGIN PRIVATE KEY-----\nMIGTAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBHkwdwIBAQQgu8gXs+XYkqXD6Ala9Sf/iJXzhbwcoG5dMh1OonpdJUmgCgYIKoZIzj0DAQehRANCAASfrvlFbFCYqn3I2zeknYXLwtH30JuOKestDbSfZYxZNMqhF/OzdZFTV0zc5u5s3eN+oCWbnvl0hM+9IW0UlkdA\n-----END PRIVATE KEY-----",
70+
},
71+
),
72+
],
73+
)
74+
),
75+
]
76+
)
77+
),
78+
],
79+
}
80+
init(**st_init_args) # type: ignore
81+
start_st()
82+
83+
state = b64encode(
84+
json.dumps({"frontendRedirectURI": "http://localhost:3000/redirect"}).encode()
85+
).decode()
86+
code = "testing"
87+
88+
data = {"state": state, "code": code}
89+
res = fastapi_client.post("/auth/callback/apple", data=data)
90+
91+
assert res.status_code == 303
92+
assert res.content == b""
93+
assert (
94+
res.headers["location"]
95+
== f"http://localhost:3000/redirect?state={state.replace('=', '%3D')}&code={code}"
96+
)
97+
98+
99+
async def exchange_auth_code_for_valid_oauth_tokens( # pylint: disable=unused-argument
100+
redirect_uri_info: RedirectUriInfo,
101+
user_context: Dict[str, Any],
102+
) -> Dict[str, Any]:
103+
return {
104+
"access_token": "accesstoken",
105+
"id_token": "idtoken",
106+
}
107+
108+
109+
async def get_user_info( # pylint: disable=unused-argument
110+
oauth_tokens: Dict[str, Any],
111+
user_context: Dict[str, Any],
112+
) -> UserInfo:
113+
time = str(datetime.datetime.now())
114+
return UserInfo(
115+
"" + time,
116+
UserInfoEmail(f"johndoeprovidertest+{time}@supertokens.com", True),
117+
RawUserInfoFromProvider({}, {}),
118+
)
119+
120+
121+
async def exchange_auth_code_for_invalid_oauth_tokens( # pylint: disable=unused-argument
122+
redirect_uri_info: RedirectUriInfo,
123+
user_context: Dict[str, Any],
124+
) -> Dict[str, Any]:
125+
return {
126+
"access_token": "wrongaccesstoken",
127+
"id_token": "wrongidtoken",
128+
}
129+
130+
131+
def get_custom_invalid_token_provider(provider: Provider) -> Provider:
132+
provider.exchange_auth_code_for_oauth_tokens = (
133+
exchange_auth_code_for_invalid_oauth_tokens
134+
)
135+
return provider
136+
137+
138+
def get_custom_valid_token_provider(provider: Provider) -> Provider:
139+
provider.exchange_auth_code_for_oauth_tokens = (
140+
exchange_auth_code_for_valid_oauth_tokens
141+
)
142+
provider.get_user_info = get_user_info
143+
return provider
144+
145+
146+
async def invalid_access_token( # pylint: disable=unused-argument
147+
access_token: str,
148+
config: ProviderConfigForClient,
149+
user_context: Optional[Dict[str, Any]],
150+
):
151+
if access_token == "wrongaccesstoken":
152+
raise Exception("Invalid access token")
153+
154+
155+
async def valid_access_token( # pylint: disable=unused-argument
156+
access_token: str, config: ProviderConfig, user_context: Optional[Dict[str, Any]]
157+
):
158+
if access_token == "accesstoken":
159+
return
160+
raise Exception("Unexpected access token")
161+
162+
163+
async def test_signinup_when_validate_access_token_throws(fastapi_client: TestClient):
164+
st_init_args = {
165+
**st_init_common_args,
166+
"recipe_list": [
167+
session.init(),
168+
thirdpartyemailpassword.init(
169+
providers=[
170+
ProviderInput(
171+
config=ProviderConfig(
172+
third_party_id="custom",
173+
clients=[
174+
ProviderClientConfig(
175+
client_id="test",
176+
client_secret="test-secret",
177+
scope=["profile", "email"],
178+
),
179+
],
180+
authorization_endpoint="https://example.com/oauth/authorize",
181+
validate_access_token=invalid_access_token,
182+
authorization_endpoint_query_params={
183+
"response_type": "token", # Changing an existing parameter
184+
"response_mode": "form", # Adding a new parameter
185+
"scope": None, # Removing a parameter
186+
},
187+
token_endpoint="https://example.com/oauth/token",
188+
),
189+
override=get_custom_invalid_token_provider,
190+
)
191+
]
192+
),
193+
],
194+
}
195+
init(**st_init_args) # type: ignore
196+
start_st()
197+
198+
res = fastapi_client.post(
199+
"/auth/signinup",
200+
json={
201+
"thirdPartyId": "custom",
202+
"redirectURIInfo": {
203+
"redirectURIOnProviderDashboard": "http://127.0.0.1/callback",
204+
"redirectURIQueryParams": {
205+
"code": "abcdefghj",
206+
},
207+
},
208+
},
209+
)
210+
assert res.status_code == 500
211+
212+
213+
# async def test_signinup_works_when_validate_access_token_does_not_throw(fastapi_client: TestClient):
214+
# st_init_args = {
215+
# **st_init_common_args,
216+
# "recipe_list": [
217+
# session.init(),
218+
# thirdpartyemailpassword.init(
219+
# providers=[
220+
# ProviderInput(
221+
# config=ProviderConfig(
222+
# third_party_id="custom",
223+
# clients=[
224+
# ProviderClientConfig(
225+
# client_id="test",
226+
# client_secret="test-secret",
227+
# scope=["profile", "email"],
228+
# ),
229+
# ],
230+
# authorization_endpoint="https://example.com/oauth/authorize",
231+
# validate_access_token=valid_access_token,
232+
# authorization_endpoint_query_params={
233+
# "response_type": "token", # Changing an existing parameter
234+
# "response_mode": "form", # Adding a new parameter
235+
# "scope": None, # Removing a parameter
236+
# },
237+
# token_endpoint="https://example.com/oauth/token",
238+
# ),
239+
# override=get_custom_valid_token_provider
240+
# )
241+
# ]
242+
# ),
243+
# ],
244+
# }
245+
#
246+
# init(**st_init_args) # type: ignore
247+
# start_st()
248+
#
249+
# res = fastapi_client.post(
250+
# "/auth/signinup",
251+
# json={
252+
# "thirdPartyId": "custom",
253+
# "redirectURIInfo": {
254+
# "redirectURIOnProviderDashboard": "http://127.0.0.1/callback",
255+
# "redirectURIQueryParams": {
256+
# "code": "abcdefghj",
257+
# },
258+
# },
259+
# }
260+
# )
261+
# assert res.status_code == 200
262+
# assert res.json()["status"] == "OK"

0 commit comments

Comments
 (0)