Skip to content

Commit 8bb55c6

Browse files
Merge pull request #428 from supertokens/feat/rate-limting
feat: Add 429 rate limting from SaaS
2 parents a3f023b + 61bb182 commit 8bb55c6

File tree

6 files changed

+191
-6
lines changed

6 files changed

+191
-6
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88

99
## [unreleased]
1010

11+
12+
13+
## [0.15.3] - 2023-09-24
14+
15+
- Handle 429 rate limiting from SaaS core instances
16+
1117
## [0.15.2] - 2023-09-23
1218

1319
- Fixed bugs in thirdparty providers: Bitbucket, Boxy-SAML, and Facebook

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070

7171
setup(
7272
name="supertokens_python",
73-
version="0.15.2",
73+
version="0.15.3",
7474
author="SuperTokens",
7575
license="Apache 2.0",
7676
author_email="[email protected]",

supertokens_python/constants.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import annotations
1515

1616
SUPPORTED_CDI_VERSIONS = ["3.0"]
17-
VERSION = "0.15.2"
17+
VERSION = "0.15.3"
1818
TELEMETRY = "/telemetry"
1919
USER_COUNT = "/users/count"
2020
USER_DELETE = "/user/remove"
@@ -29,3 +29,4 @@
2929
API_VERSION_HEADER = "cdi-version"
3030
DASHBOARD_VERSION = "0.7"
3131
HUNDRED_YEARS_IN_MS = 3153600000000
32+
RATE_LIMIT_STATUS_CODE = 429

supertokens_python/normalised_url_path.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ def equals(self, other: NormalisedURLPath) -> bool:
4040

4141
def is_a_recipe_path(self) -> bool:
4242
parts = self.__value.split("/")
43-
return parts[1] == "recipe" or parts[2] == "recipe"
43+
return (len(parts) > 1 and parts[1] == "recipe") or (
44+
len(parts) > 2 and parts[2] == "recipe"
45+
)
4446

4547

4648
def normalise_url_path_or_throw_error(input_str: str) -> str:

supertokens_python/querier.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
# under the License.
1414
from __future__ import annotations
1515

16+
import asyncio
17+
1618
from json import JSONDecodeError
1719
from os import environ
18-
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict
20+
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional
1921

2022
from httpx import AsyncClient, ConnectTimeout, NetworkError, Response
2123

@@ -25,6 +27,7 @@
2527
API_VERSION_HEADER,
2628
RID_KEY_HEADER,
2729
SUPPORTED_CDI_VERSIONS,
30+
RATE_LIMIT_STATUS_CODE,
2831
)
2932
from .normalised_url_path import NormalisedURLPath
3033

@@ -222,6 +225,7 @@ async def __send_request_helper(
222225
method: str,
223226
http_function: Callable[[str], Awaitable[Response]],
224227
no_of_tries: int,
228+
retry_info_map: Optional[Dict[str, int]] = None,
225229
) -> Any:
226230
if no_of_tries == 0:
227231
raise_general_exception("No SuperTokens core available to query")
@@ -238,6 +242,14 @@ async def __send_request_helper(
238242
Querier.__last_tried_index %= len(self.__hosts)
239243
url = current_host + path.get_as_string_dangerous()
240244

245+
max_retries = 5
246+
247+
if retry_info_map is None:
248+
retry_info_map = {}
249+
250+
if retry_info_map.get(url) is None:
251+
retry_info_map[url] = max_retries
252+
241253
ProcessState.get_instance().add_state(
242254
AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER
243255
)
@@ -247,6 +259,20 @@ async def __send_request_helper(
247259
):
248260
Querier.__hosts_alive_for_testing.add(current_host)
249261

262+
if response.status_code == RATE_LIMIT_STATUS_CODE:
263+
retries_left = retry_info_map[url]
264+
265+
if retries_left > 0:
266+
retry_info_map[url] = retries_left - 1
267+
268+
attempts_made = max_retries - retries_left
269+
delay = (10 + attempts_made * 250) / 1000
270+
271+
await asyncio.sleep(delay)
272+
return await self.__send_request_helper(
273+
path, method, http_function, no_of_tries, retry_info_map
274+
)
275+
250276
if is_4xx_error(response.status_code) or is_5xx_error(response.status_code): # type: ignore
251277
raise_general_exception(
252278
"SuperTokens core threw an error for a "
@@ -264,9 +290,9 @@ async def __send_request_helper(
264290
except JSONDecodeError:
265291
return response.text
266292

267-
except (ConnectionError, NetworkError, ConnectTimeout):
293+
except (ConnectionError, NetworkError, ConnectTimeout) as _:
268294
return await self.__send_request_helper(
269-
path, method, http_function, no_of_tries - 1
295+
path, method, http_function, no_of_tries - 1, retry_info_map
270296
)
271297
except Exception as e:
272298
raise_general_exception(e)

tests/test_querier.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
2+
#
3+
# This software is licensed under the Apache License, Version 2.0 (the
4+
# "License") as published by the Apache Software Foundation.
5+
#
6+
# You may not use this file except in compliance with the License. You may
7+
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations
13+
# under the License.
14+
from pytest import mark
15+
from supertokens_python.recipe import (
16+
session,
17+
emailpassword,
18+
emailverification,
19+
dashboard,
20+
)
21+
import asyncio
22+
import respx
23+
import httpx
24+
from supertokens_python import init, SupertokensConfig
25+
from supertokens_python.querier import Querier, NormalisedURLPath
26+
27+
from tests.utils import get_st_init_args
28+
from tests.utils import (
29+
setup_function,
30+
teardown_function,
31+
start_st,
32+
)
33+
34+
_ = setup_function
35+
_ = teardown_function
36+
37+
pytestmark = mark.asyncio
38+
respx_mock = respx.MockRouter
39+
40+
41+
async def test_network_call_is_retried_as_expected():
42+
# Test that network call is retried properly
43+
# Test that rate limiting errors are thrown back to the user
44+
args = get_st_init_args(
45+
[
46+
session.init(),
47+
emailpassword.init(),
48+
emailverification.init(mode="OPTIONAL"),
49+
dashboard.init(),
50+
]
51+
)
52+
args["supertokens_config"] = SupertokensConfig("http://localhost:6789")
53+
init(**args) # type: ignore
54+
start_st()
55+
56+
Querier.api_version = "3.0"
57+
q = Querier.get_instance()
58+
59+
api2_call_count = 0
60+
61+
def api2_side_effect(_: httpx.Request):
62+
nonlocal api2_call_count
63+
api2_call_count += 1
64+
65+
if api2_call_count == 3:
66+
return httpx.Response(200)
67+
68+
return httpx.Response(429, json={})
69+
70+
with respx_mock() as mocker:
71+
api1 = mocker.get("http://localhost:6789/api1").mock(
72+
httpx.Response(429, json={"status": "RATE_ERROR"})
73+
)
74+
api2 = mocker.get("http://localhost:6789/api2").mock(
75+
side_effect=api2_side_effect
76+
)
77+
api3 = mocker.get("http://localhost:6789/api3").mock(httpx.Response(200))
78+
79+
try:
80+
await q.send_get_request(NormalisedURLPath("/api1"), {})
81+
except Exception as e:
82+
if "with status code: 429" in str(
83+
e
84+
) and 'message: {"status": "RATE_ERROR"}' in str(e):
85+
pass
86+
else:
87+
raise e
88+
89+
await q.send_get_request(NormalisedURLPath("/api2"), {})
90+
await q.send_get_request(NormalisedURLPath("/api3"), {})
91+
92+
# 1 initial request + 5 retries
93+
assert api1.call_count == 6
94+
# 2 403 and 1 200
95+
assert api2.call_count == 3
96+
# 200 in the first attempt
97+
assert api3.call_count == 1
98+
99+
100+
async def test_parallel_calls_have_independent_counters():
101+
args = get_st_init_args(
102+
[
103+
session.init(),
104+
emailpassword.init(),
105+
emailverification.init(mode="OPTIONAL"),
106+
dashboard.init(),
107+
]
108+
)
109+
init(**args) # type: ignore
110+
start_st()
111+
112+
Querier.api_version = "3.0"
113+
q = Querier.get_instance()
114+
115+
call_count1 = 0
116+
call_count2 = 0
117+
118+
def api_side_effect(r: httpx.Request):
119+
nonlocal call_count1, call_count2
120+
121+
id_ = int(r.url.params.get("id"))
122+
if id_ == 1:
123+
call_count1 += 1
124+
elif id_ == 2:
125+
call_count2 += 1
126+
127+
return httpx.Response(429, json={})
128+
129+
with respx_mock() as mocker:
130+
api = mocker.get("http://localhost:3567/api").mock(side_effect=api_side_effect)
131+
132+
async def call_api(id_: int):
133+
try:
134+
await q.send_get_request(NormalisedURLPath("/api"), {"id": id_})
135+
except Exception as e:
136+
if "with status code: 429" in str(e):
137+
pass
138+
else:
139+
raise e
140+
141+
_ = await asyncio.gather(
142+
call_api(1),
143+
call_api(2),
144+
)
145+
146+
# 1 initial request + 5 retries
147+
assert call_count1 == 6
148+
assert call_count2 == 6
149+
150+
assert api.call_count == 12

0 commit comments

Comments
 (0)