Skip to content

Commit 01e7e43

Browse files
Merge pull request #432 from supertokens/feat/rate-limiting-0.12
feat: Add 429 rate limting from SaaS for v0.12
2 parents 9e60f3d + d85ba34 commit 01e7e43

File tree

6 files changed

+196
-22
lines changed

6 files changed

+196
-22
lines changed

CHANGELOG.md

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## unreleased
99

10+
## [0.12.10] - 2023-09-28
11+
12+
- Add logic to retry network calls if the core returns status 429
1013

1114
## [0.12.9] - 2023-04-28
1215

@@ -267,7 +270,7 @@ init(
267270
def verify_email_for_passwordless_users():
268271
pagination_token = None
269272
done = False
270-
273+
271274
while not done:
272275
res = get_users_newest_first(
273276
limit=100,
@@ -280,7 +283,7 @@ def verify_email_for_passwordless_users():
280283
token_res = create_email_verification_token(user.user_id, user.email)
281284
if isinstance(token_res, CreateEmailVerificationTokenOkResult):
282285
verify_email_using_token(token_res.token)
283-
286+
284287
done = res.next_pagination_token is None
285288
if not done:
286289
pagination_token = res.next_pagination_token
@@ -310,7 +313,7 @@ The `UserRoles` recipe now adds role and permission information into the access
310313

311314
## [0.10.2] - 2022-07-14
312315
### Bug fix
313-
- Make `user_context` optional in userroles recipe syncio functions.
316+
- Make `user_context` optional in userroles recipe syncio functions.
314317

315318
## [0.10.1] - 2022-07-11
316319

@@ -845,4 +848,4 @@ init(
845848
- Middleware, error handlers and verify session for each framework.
846849
- Created a wrapper for async to sync for supporting older version of python web frameworks.
847850
- Base tests for each framework.
848-
- New requirements in the setup file.
851+
- New requirements in the setup file.

addDevTag

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,5 @@
11
#!/bin/bash
22

3-
# check if we need to merge master into this branch------------
4-
if [[ $(git log origin/master ^HEAD) ]]; then
5-
echo "You need to merge master into this branch. Exiting"
6-
exit 1
7-
fi
8-
93
# get version------------
104
version=`cat setup.py | grep -e 'version='`
115
while IFS='"' read -ra ADDR; do
@@ -86,4 +80,4 @@ fi
8680

8781
git tag dev-v$version $commit_hash
8882

89-
git push --tags
83+
git push --tags

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070

7171
setup(
7272
name="supertokens_python",
73-
version="0.12.9",
73+
version="0.12.10",
7474
author="SuperTokens",
7575
license="Apache 2.0",
7676
author_email="[email protected]",

supertokens_python/constants.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
"2.19",
2626
"2.20",
2727
]
28-
VERSION = "0.12.9"
28+
VERSION = "0.12.10"
2929
TELEMETRY = "/telemetry"
3030
USER_COUNT = "/users/count"
3131
USER_DELETE = "/user/remove"
@@ -39,3 +39,4 @@
3939
API_VERSION = "/apiversion"
4040
API_VERSION_HEADER = "cdi-version"
4141
DASHBOARD_VERSION = "0.6"
42+
RATE_LIMIT_STATUS_CODE = 429

supertokens_python/querier.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
# under the License.
1414
from __future__ import annotations
1515

16+
import asyncio
17+
1618
from json import JSONDecodeError
1719
from os import environ
18-
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict
20+
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional
1921

2022
from httpx import AsyncClient, ConnectTimeout, NetworkError, Response
2123

@@ -25,6 +27,7 @@
2527
API_VERSION_HEADER,
2628
RID_KEY_HEADER,
2729
SUPPORTED_CDI_VERSIONS,
30+
RATE_LIMIT_STATUS_CODE,
2831
)
2932
from .normalised_url_path import NormalisedURLPath
3033

@@ -42,7 +45,7 @@ class Querier:
4245
__init_called = False
4346
__hosts: List[Host] = []
4447
__api_key: Union[None, str] = None
45-
__api_version = None
48+
api_version = None
4649
__last_tried_index: int = 0
4750
__hosts_alive_for_testing: Set[str] = set()
4851

@@ -69,8 +72,8 @@ def get_hosts_alive_for_testing():
6972
return Querier.__hosts_alive_for_testing
7073

7174
async def get_api_version(self):
72-
if Querier.__api_version is not None:
73-
return Querier.__api_version
75+
if Querier.api_version is not None:
76+
return Querier.api_version
7477

7578
ProcessState.get_instance().add_state(
7679
AllowedProcessStates.CALLING_SERVICE_IN_GET_API_VERSION
@@ -96,8 +99,8 @@ async def f(url: str) -> Response:
9699
"to find the right versions"
97100
)
98101

99-
Querier.__api_version = api_version
100-
return Querier.__api_version
102+
Querier.api_version = api_version
103+
return Querier.api_version
101104

102105
@staticmethod
103106
def get_instance(rid_to_core: Union[str, None] = None):
@@ -113,7 +116,7 @@ def init(hosts: List[Host], api_key: Union[str, None] = None):
113116
Querier.__init_called = True
114117
Querier.__hosts = hosts
115118
Querier.__api_key = api_key
116-
Querier.__api_version = None
119+
Querier.api_version = None
117120
Querier.__last_tried_index = 0
118121
Querier.__hosts_alive_for_testing = set()
119122

@@ -203,6 +206,7 @@ async def __send_request_helper(
203206
method: str,
204207
http_function: Callable[[str], Awaitable[Response]],
205208
no_of_tries: int,
209+
retry_info_map: Optional[Dict[str, int]] = None,
206210
) -> Any:
207211
if no_of_tries == 0:
208212
raise_general_exception("No SuperTokens core available to query")
@@ -219,6 +223,14 @@ async def __send_request_helper(
219223
Querier.__last_tried_index %= len(self.__hosts)
220224
url = current_host + path.get_as_string_dangerous()
221225

226+
max_retries = 5
227+
228+
if retry_info_map is None:
229+
retry_info_map = {}
230+
231+
if retry_info_map.get(url) is None:
232+
retry_info_map[url] = max_retries
233+
222234
ProcessState.get_instance().add_state(
223235
AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER
224236
)
@@ -228,6 +240,20 @@ async def __send_request_helper(
228240
):
229241
Querier.__hosts_alive_for_testing.add(current_host)
230242

243+
if response.status_code == RATE_LIMIT_STATUS_CODE:
244+
retries_left = retry_info_map[url]
245+
246+
if retries_left > 0:
247+
retry_info_map[url] = retries_left - 1
248+
249+
attempts_made = max_retries - retries_left
250+
delay = (10 + attempts_made * 250) / 1000
251+
252+
await asyncio.sleep(delay)
253+
return await self.__send_request_helper(
254+
path, method, http_function, no_of_tries, retry_info_map
255+
)
256+
231257
if is_4xx_error(response.status_code) or is_5xx_error(response.status_code): # type: ignore
232258
raise_general_exception(
233259
"SuperTokens core threw an error for a "
@@ -245,9 +271,9 @@ async def __send_request_helper(
245271
except JSONDecodeError:
246272
return response.text
247273

248-
except (ConnectionError, NetworkError, ConnectTimeout):
274+
except (ConnectionError, NetworkError, ConnectTimeout) as _:
249275
return await self.__send_request_helper(
250-
path, method, http_function, no_of_tries - 1
276+
path, method, http_function, no_of_tries - 1, retry_info_map
251277
)
252278
except Exception as e:
253279
raise_general_exception(e)

tests/test_querier.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
2+
#
3+
# This software is licensed under the Apache License, Version 2.0 (the
4+
# "License") as published by the Apache Software Foundation.
5+
#
6+
# You may not use this file except in compliance with the License. You may
7+
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations
13+
# under the License.
14+
from pytest import mark
15+
from supertokens_python.recipe import (
16+
session,
17+
emailpassword,
18+
emailverification,
19+
dashboard,
20+
)
21+
import asyncio
22+
import respx
23+
import httpx
24+
from supertokens_python import init, SupertokensConfig
25+
from supertokens_python.querier import Querier, NormalisedURLPath
26+
27+
from tests.utils import get_st_init_args
28+
from tests.utils import (
29+
setup_function,
30+
teardown_function,
31+
start_st,
32+
)
33+
34+
_ = setup_function
35+
_ = teardown_function
36+
37+
pytestmark = mark.asyncio
38+
respx_mock = respx.MockRouter
39+
40+
41+
async def test_network_call_is_retried_as_expected():
42+
# Test that network call is retried properly
43+
# Test that rate limiting errors are thrown back to the user
44+
args = get_st_init_args(
45+
[
46+
session.init(),
47+
emailpassword.init(),
48+
emailverification.init(mode="OPTIONAL"),
49+
dashboard.init(),
50+
]
51+
)
52+
args["supertokens_config"] = SupertokensConfig("http://localhost:6789")
53+
init(**args) # type: ignore
54+
start_st()
55+
56+
Querier.api_version = "3.0"
57+
q = Querier.get_instance()
58+
59+
api2_call_count = 0
60+
61+
def api2_side_effect(_: httpx.Request):
62+
nonlocal api2_call_count
63+
api2_call_count += 1
64+
65+
if api2_call_count == 3:
66+
return httpx.Response(200)
67+
68+
return httpx.Response(429, json={})
69+
70+
with respx_mock() as mocker:
71+
api1 = mocker.get("http://localhost:6789/api1").mock(
72+
httpx.Response(429, json={"status": "RATE_ERROR"})
73+
)
74+
api2 = mocker.get("http://localhost:6789/api2").mock(
75+
side_effect=api2_side_effect
76+
)
77+
api3 = mocker.get("http://localhost:6789/api3").mock(httpx.Response(200))
78+
79+
try:
80+
await q.send_get_request(NormalisedURLPath("/api1"), {})
81+
except Exception as e:
82+
if "with status code: 429" in str(
83+
e
84+
) and 'message: {"status": "RATE_ERROR"}' in str(e):
85+
pass
86+
else:
87+
raise e
88+
89+
await q.send_get_request(NormalisedURLPath("/api2"), {})
90+
await q.send_get_request(NormalisedURLPath("/api3"), {})
91+
92+
# 1 initial request + 5 retries
93+
assert api1.call_count == 6
94+
# 2 403 and 1 200
95+
assert api2.call_count == 3
96+
# 200 in the first attempt
97+
assert api3.call_count == 1
98+
99+
100+
async def test_parallel_calls_have_independent_counters():
101+
args = get_st_init_args(
102+
[
103+
session.init(),
104+
emailpassword.init(),
105+
emailverification.init(mode="OPTIONAL"),
106+
dashboard.init(),
107+
]
108+
)
109+
init(**args) # type: ignore
110+
start_st()
111+
112+
Querier.api_version = "3.0"
113+
q = Querier.get_instance()
114+
115+
call_count1 = 0
116+
call_count2 = 0
117+
118+
def api_side_effect(r: httpx.Request):
119+
nonlocal call_count1, call_count2
120+
121+
id_ = int(r.url.params.get("id"))
122+
if id_ == 1:
123+
call_count1 += 1
124+
elif id_ == 2:
125+
call_count2 += 1
126+
127+
return httpx.Response(429, json={})
128+
129+
with respx_mock() as mocker:
130+
api = mocker.get("http://localhost:3567/api").mock(side_effect=api_side_effect)
131+
132+
async def call_api(id_: int):
133+
try:
134+
await q.send_get_request(NormalisedURLPath("/api"), {"id": id_})
135+
except Exception as e:
136+
if "with status code: 429" in str(e):
137+
pass
138+
else:
139+
raise e
140+
141+
_ = await asyncio.gather(
142+
call_api(1),
143+
call_api(2),
144+
)
145+
146+
# 1 initial request + 5 retries
147+
assert call_count1 == 6
148+
assert call_count2 == 6
149+
150+
assert api.call_count == 12

0 commit comments

Comments
 (0)