Skip to content

Commit c1d7086

Browse files
committed
feat: Add 429 rate limting from SaaS
1 parent 9e60f3d commit c1d7086

File tree

3 files changed

+180
-3
lines changed

3 files changed

+180
-3
lines changed

supertokens_python/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,4 @@
3939
API_VERSION = "/apiversion"
4040
API_VERSION_HEADER = "cdi-version"
4141
DASHBOARD_VERSION = "0.6"
42+
RATE_LIMIT_STATUS_CODE = 429

supertokens_python/querier.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
# under the License.
1414
from __future__ import annotations
1515

16+
import asyncio
17+
1618
from json import JSONDecodeError
1719
from os import environ
18-
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict
20+
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional
1921

2022
from httpx import AsyncClient, ConnectTimeout, NetworkError, Response
2123

@@ -25,6 +27,7 @@
2527
API_VERSION_HEADER,
2628
RID_KEY_HEADER,
2729
SUPPORTED_CDI_VERSIONS,
30+
RATE_LIMIT_STATUS_CODE,
2831
)
2932
from .normalised_url_path import NormalisedURLPath
3033

@@ -203,6 +206,7 @@ async def __send_request_helper(
203206
method: str,
204207
http_function: Callable[[str], Awaitable[Response]],
205208
no_of_tries: int,
209+
retry_info_map: Optional[Dict[str, int]] = None,
206210
) -> Any:
207211
if no_of_tries == 0:
208212
raise_general_exception("No SuperTokens core available to query")
@@ -219,6 +223,14 @@ async def __send_request_helper(
219223
Querier.__last_tried_index %= len(self.__hosts)
220224
url = current_host + path.get_as_string_dangerous()
221225

226+
max_retries = 5
227+
228+
if retry_info_map is None:
229+
retry_info_map = {}
230+
231+
if retry_info_map.get(url) is None:
232+
retry_info_map[url] = max_retries
233+
222234
ProcessState.get_instance().add_state(
223235
AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER
224236
)
@@ -228,6 +240,20 @@ async def __send_request_helper(
228240
):
229241
Querier.__hosts_alive_for_testing.add(current_host)
230242

243+
if response.status_code == RATE_LIMIT_STATUS_CODE:
244+
retries_left = retry_info_map[url]
245+
246+
if retries_left > 0:
247+
retry_info_map[url] = retries_left - 1
248+
249+
attempts_made = max_retries - retries_left
250+
delay = (10 + attempts_made * 250) / 1000
251+
252+
await asyncio.sleep(delay)
253+
return await self.__send_request_helper(
254+
path, method, http_function, no_of_tries, retry_info_map
255+
)
256+
231257
if is_4xx_error(response.status_code) or is_5xx_error(response.status_code): # type: ignore
232258
raise_general_exception(
233259
"SuperTokens core threw an error for a "
@@ -245,9 +271,9 @@ async def __send_request_helper(
245271
except JSONDecodeError:
246272
return response.text
247273

248-
except (ConnectionError, NetworkError, ConnectTimeout):
274+
except (ConnectionError, NetworkError, ConnectTimeout) as _:
249275
return await self.__send_request_helper(
250-
path, method, http_function, no_of_tries - 1
276+
path, method, http_function, no_of_tries - 1, retry_info_map
251277
)
252278
except Exception as e:
253279
raise_general_exception(e)

tests/test_querier.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
2+
#
3+
# This software is licensed under the Apache License, Version 2.0 (the
4+
# "License") as published by the Apache Software Foundation.
5+
#
6+
# You may not use this file except in compliance with the License. You may
7+
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations
13+
# under the License.
14+
from pytest import mark
15+
from supertokens_python.recipe import (
16+
session,
17+
emailpassword,
18+
emailverification,
19+
dashboard,
20+
)
21+
import asyncio
22+
import respx
23+
import httpx
24+
from supertokens_python import init, SupertokensConfig
25+
from supertokens_python.querier import Querier, NormalisedURLPath
26+
27+
from tests.utils import get_st_init_args
28+
from tests.utils import (
29+
setup_function,
30+
teardown_function,
31+
start_st,
32+
)
33+
34+
_ = setup_function
35+
_ = teardown_function
36+
37+
pytestmark = mark.asyncio
38+
respx_mock = respx.MockRouter
39+
40+
41+
async def test_network_call_is_retried_as_expected():
42+
# Test that network call is retried properly
43+
# Test that rate limiting errors are thrown back to the user
44+
args = get_st_init_args(
45+
[
46+
session.init(),
47+
emailpassword.init(),
48+
emailverification.init(mode="OPTIONAL"),
49+
dashboard.init(),
50+
]
51+
)
52+
args["supertokens_config"] = SupertokensConfig("http://localhost:6789")
53+
init(**args) # type: ignore
54+
start_st()
55+
56+
Querier.api_version = "3.0"
57+
q = Querier.get_instance()
58+
59+
api2_call_count = 0
60+
61+
def api2_side_effect(_: httpx.Request):
62+
nonlocal api2_call_count
63+
api2_call_count += 1
64+
65+
if api2_call_count == 3:
66+
return httpx.Response(200)
67+
68+
return httpx.Response(429, json={})
69+
70+
with respx_mock() as mocker:
71+
api1 = mocker.get("http://localhost:6789/api1").mock(
72+
httpx.Response(429, json={"status": "RATE_ERROR"})
73+
)
74+
api2 = mocker.get("http://localhost:6789/api2").mock(
75+
side_effect=api2_side_effect
76+
)
77+
api3 = mocker.get("http://localhost:6789/api3").mock(httpx.Response(200))
78+
79+
try:
80+
await q.send_get_request(NormalisedURLPath("/api1"), {})
81+
except Exception as e:
82+
if "with status code: 429" in str(
83+
e
84+
) and 'message: {"status": "RATE_ERROR"}' in str(e):
85+
pass
86+
else:
87+
raise e
88+
89+
await q.send_get_request(NormalisedURLPath("/api2"), {})
90+
await q.send_get_request(NormalisedURLPath("/api3"), {})
91+
92+
# 1 initial request + 5 retries
93+
assert api1.call_count == 6
94+
# 2 403 and 1 200
95+
assert api2.call_count == 3
96+
# 200 in the first attempt
97+
assert api3.call_count == 1
98+
99+
100+
async def test_parallel_calls_have_independent_counters():
101+
args = get_st_init_args(
102+
[
103+
session.init(),
104+
emailpassword.init(),
105+
emailverification.init(mode="OPTIONAL"),
106+
dashboard.init(),
107+
]
108+
)
109+
init(**args) # type: ignore
110+
start_st()
111+
112+
Querier.api_version = "3.0"
113+
q = Querier.get_instance()
114+
115+
call_count1 = 0
116+
call_count2 = 0
117+
118+
def api_side_effect(r: httpx.Request):
119+
nonlocal call_count1, call_count2
120+
121+
id_ = int(r.url.params.get("id"))
122+
if id_ == 1:
123+
call_count1 += 1
124+
elif id_ == 2:
125+
call_count2 += 1
126+
127+
return httpx.Response(429, json={})
128+
129+
with respx_mock() as mocker:
130+
api = mocker.get("http://localhost:3567/api").mock(side_effect=api_side_effect)
131+
132+
async def call_api(id_: int):
133+
try:
134+
await q.send_get_request(NormalisedURLPath("/api"), {"id": id_})
135+
except Exception as e:
136+
if "with status code: 429" in str(e):
137+
pass
138+
else:
139+
raise e
140+
141+
_ = await asyncio.gather(
142+
call_api(1),
143+
call_api(2),
144+
)
145+
146+
# 1 initial request + 5 retries
147+
assert call_count1 == 6
148+
assert call_count2 == 6
149+
150+
assert api.call_count == 12

0 commit comments

Comments
 (0)