Skip to content

Commit f2c8d51

Browse files
Merge pull request #430 from supertokens/feat/rate-limiting-0.14
feat: Add 429 rate limting from SaaS for v0.14
2 parents e6e798d + 4d67c04 commit f2c8d51

File tree

5 files changed

+187
-6
lines changed

5 files changed

+187
-6
lines changed

CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88

99
## [unreleased]
1010

11+
## [0.14.9] - 2023-09-28
12+
13+
- Add logic to retry network calls if the core returns status 429
14+
1115
## [0.14.8] - 2023-07-07
1216
## Fixes
1317

@@ -148,7 +152,7 @@ if (accessTokenPayload.sub !== undefined) {
148152
```python
149153
from supertokens_python.recipe.session.interfaces import SessionContainer
150154

151-
session: SessionContainer = ...
155+
session: SessionContainer = ...
152156
access_token_payload = await session.get_access_token_payload()
153157

154158
if access_token_payload.get('sub') is not None:

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070

7171
setup(
7272
name="supertokens_python",
73-
version="0.14.8",
73+
version="0.14.9",
7474
author="SuperTokens",
7575
license="Apache 2.0",
7676
author_email="[email protected]",

supertokens_python/constants.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import annotations
1515

1616
SUPPORTED_CDI_VERSIONS = ["2.21"]
17-
VERSION = "0.14.8"
17+
VERSION = "0.14.9"
1818
TELEMETRY = "/telemetry"
1919
USER_COUNT = "/users/count"
2020
USER_DELETE = "/user/remove"
@@ -29,3 +29,4 @@
2929
API_VERSION_HEADER = "cdi-version"
3030
DASHBOARD_VERSION = "0.6"
3131
HUNDRED_YEARS_IN_MS = 3153600000000
32+
RATE_LIMIT_STATUS_CODE = 429

supertokens_python/querier.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
# under the License.
1414
from __future__ import annotations
1515

16+
import asyncio
17+
1618
from json import JSONDecodeError
1719
from os import environ
18-
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict
20+
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional
1921

2022
from httpx import AsyncClient, ConnectTimeout, NetworkError, Response
2123

@@ -25,6 +27,7 @@
2527
API_VERSION_HEADER,
2628
RID_KEY_HEADER,
2729
SUPPORTED_CDI_VERSIONS,
30+
RATE_LIMIT_STATUS_CODE,
2831
)
2932
from .normalised_url_path import NormalisedURLPath
3033

@@ -222,6 +225,7 @@ async def __send_request_helper(
222225
method: str,
223226
http_function: Callable[[str], Awaitable[Response]],
224227
no_of_tries: int,
228+
retry_info_map: Optional[Dict[str, int]] = None,
225229
) -> Any:
226230
if no_of_tries == 0:
227231
raise_general_exception("No SuperTokens core available to query")
@@ -238,6 +242,14 @@ async def __send_request_helper(
238242
Querier.__last_tried_index %= len(self.__hosts)
239243
url = current_host + path.get_as_string_dangerous()
240244

245+
max_retries = 5
246+
247+
if retry_info_map is None:
248+
retry_info_map = {}
249+
250+
if retry_info_map.get(url) is None:
251+
retry_info_map[url] = max_retries
252+
241253
ProcessState.get_instance().add_state(
242254
AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER
243255
)
@@ -247,6 +259,20 @@ async def __send_request_helper(
247259
):
248260
Querier.__hosts_alive_for_testing.add(current_host)
249261

262+
if response.status_code == RATE_LIMIT_STATUS_CODE:
263+
retries_left = retry_info_map[url]
264+
265+
if retries_left > 0:
266+
retry_info_map[url] = retries_left - 1
267+
268+
attempts_made = max_retries - retries_left
269+
delay = (10 + attempts_made * 250) / 1000
270+
271+
await asyncio.sleep(delay)
272+
return await self.__send_request_helper(
273+
path, method, http_function, no_of_tries, retry_info_map
274+
)
275+
250276
if is_4xx_error(response.status_code) or is_5xx_error(response.status_code): # type: ignore
251277
raise_general_exception(
252278
"SuperTokens core threw an error for a "
@@ -264,9 +290,9 @@ async def __send_request_helper(
264290
except JSONDecodeError:
265291
return response.text
266292

267-
except (ConnectionError, NetworkError, ConnectTimeout):
293+
except (ConnectionError, NetworkError, ConnectTimeout) as _:
268294
return await self.__send_request_helper(
269-
path, method, http_function, no_of_tries - 1
295+
path, method, http_function, no_of_tries - 1, retry_info_map
270296
)
271297
except Exception as e:
272298
raise_general_exception(e)

tests/test_querier.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
2+
#
3+
# This software is licensed under the Apache License, Version 2.0 (the
4+
# "License") as published by the Apache Software Foundation.
5+
#
6+
# You may not use this file except in compliance with the License. You may
7+
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations
13+
# under the License.
14+
from pytest import mark
15+
from supertokens_python.recipe import (
16+
session,
17+
emailpassword,
18+
emailverification,
19+
dashboard,
20+
)
21+
import asyncio
22+
import respx
23+
import httpx
24+
from supertokens_python import init, SupertokensConfig
25+
from supertokens_python.querier import Querier, NormalisedURLPath
26+
27+
from tests.utils import get_st_init_args
28+
from tests.utils import (
29+
setup_function,
30+
teardown_function,
31+
start_st,
32+
)
33+
34+
_ = setup_function
35+
_ = teardown_function
36+
37+
pytestmark = mark.asyncio
38+
respx_mock = respx.MockRouter
39+
40+
41+
async def test_network_call_is_retried_as_expected():
42+
# Test that network call is retried properly
43+
# Test that rate limiting errors are thrown back to the user
44+
args = get_st_init_args(
45+
[
46+
session.init(),
47+
emailpassword.init(),
48+
emailverification.init(mode="OPTIONAL"),
49+
dashboard.init(),
50+
]
51+
)
52+
args["supertokens_config"] = SupertokensConfig("http://localhost:6789")
53+
init(**args) # type: ignore
54+
start_st()
55+
56+
Querier.api_version = "3.0"
57+
q = Querier.get_instance()
58+
59+
api2_call_count = 0
60+
61+
def api2_side_effect(_: httpx.Request):
62+
nonlocal api2_call_count
63+
api2_call_count += 1
64+
65+
if api2_call_count == 3:
66+
return httpx.Response(200)
67+
68+
return httpx.Response(429, json={})
69+
70+
with respx_mock() as mocker:
71+
api1 = mocker.get("http://localhost:6789/api1").mock(
72+
httpx.Response(429, json={"status": "RATE_ERROR"})
73+
)
74+
api2 = mocker.get("http://localhost:6789/api2").mock(
75+
side_effect=api2_side_effect
76+
)
77+
api3 = mocker.get("http://localhost:6789/api3").mock(httpx.Response(200))
78+
79+
try:
80+
await q.send_get_request(NormalisedURLPath("/api1"), {})
81+
except Exception as e:
82+
if "with status code: 429" in str(
83+
e
84+
) and 'message: {"status": "RATE_ERROR"}' in str(e):
85+
pass
86+
else:
87+
raise e
88+
89+
await q.send_get_request(NormalisedURLPath("/api2"), {})
90+
await q.send_get_request(NormalisedURLPath("/api3"), {})
91+
92+
# 1 initial request + 5 retries
93+
assert api1.call_count == 6
94+
# 2 403 and 1 200
95+
assert api2.call_count == 3
96+
# 200 in the first attempt
97+
assert api3.call_count == 1
98+
99+
100+
async def test_parallel_calls_have_independent_counters():
101+
args = get_st_init_args(
102+
[
103+
session.init(),
104+
emailpassword.init(),
105+
emailverification.init(mode="OPTIONAL"),
106+
dashboard.init(),
107+
]
108+
)
109+
init(**args) # type: ignore
110+
start_st()
111+
112+
Querier.api_version = "3.0"
113+
q = Querier.get_instance()
114+
115+
call_count1 = 0
116+
call_count2 = 0
117+
118+
def api_side_effect(r: httpx.Request):
119+
nonlocal call_count1, call_count2
120+
121+
id_ = int(r.url.params.get("id"))
122+
if id_ == 1:
123+
call_count1 += 1
124+
elif id_ == 2:
125+
call_count2 += 1
126+
127+
return httpx.Response(429, json={})
128+
129+
with respx_mock() as mocker:
130+
api = mocker.get("http://localhost:3567/api").mock(side_effect=api_side_effect)
131+
132+
async def call_api(id_: int):
133+
try:
134+
await q.send_get_request(NormalisedURLPath("/api"), {"id": id_})
135+
except Exception as e:
136+
if "with status code: 429" in str(e):
137+
pass
138+
else:
139+
raise e
140+
141+
_ = await asyncio.gather(
142+
call_api(1),
143+
call_api(2),
144+
)
145+
146+
# 1 initial request + 5 retries
147+
assert call_count1 == 6
148+
assert call_count2 == 6
149+
150+
assert api.call_count == 12

0 commit comments

Comments
 (0)