Skip to content

Commit 5e4c6a4

Browse files
Merge pull request #433 from supertokens/feat/rate-limiting-0.11
feat: Add 429 rate limting from SaaS for v0.11
2 parents a3dd752 + 99d8508 commit 5e4c6a4

File tree

7 files changed

+195
-24
lines changed

7 files changed

+195
-24
lines changed

CHANGELOG.md

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

88
## unreleased
9+
10+
# [0.11.14] - 2023-09-01
11+
12+
- Add logic to retry network calls if the core returns status 429
13+
14+
915
# [0.11.13] - 2023-01-06
1016

1117
- Add missing `original` attribute to flask response and remove logic for cases where `response` is `None`
@@ -205,7 +211,7 @@ init(
205211
def verify_email_for_passwordless_users():
206212
pagination_token = None
207213
done = False
208-
214+
209215
while not done:
210216
res = get_users_newest_first(
211217
limit=100,
@@ -218,7 +224,7 @@ def verify_email_for_passwordless_users():
218224
token_res = create_email_verification_token(user.user_id, user.email)
219225
if isinstance(token_res, CreateEmailVerificationTokenOkResult):
220226
verify_email_using_token(token_res.token)
221-
227+
222228
done = res.next_pagination_token is None
223229
if not done:
224230
pagination_token = res.next_pagination_token
@@ -248,7 +254,7 @@ The `UserRoles` recipe now adds role and permission information into the access
248254

249255
## [0.10.2] - 2022-07-14
250256
### Bug fix
251-
- Make `user_context` optional in userroles recipe syncio functions.
257+
- Make `user_context` optional in userroles recipe syncio functions.
252258

253259
## [0.10.1] - 2022-07-11
254260

@@ -783,4 +789,4 @@ init(
783789
- Middleware, error handlers and verify session for each framework.
784790
- Created a wrapper for async to sync for supporting older version of python web frameworks.
785791
- Base tests for each framework.
786-
- New requirements in the setup file.
792+
- New requirements in the setup file.

addDevTag

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,5 @@
11
#!/bin/bash
22

3-
# check if we need to merge master into this branch------------
4-
if [[ $(git log origin/master ^HEAD) ]]; then
5-
echo "You need to merge master into this branch. Exiting"
6-
exit 1
7-
fi
8-
93
# get version------------
104
version=`cat setup.py | grep -e 'version='`
115
while IFS='"' read -ra ADDR; do
@@ -86,4 +80,4 @@ fi
8680

8781
git tag dev-v$version $commit_hash
8882

89-
git push --tags
83+
git push --tags

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070

7171
setup(
7272
name="supertokens_python",
73-
version="0.11.13",
73+
version="0.11.14",
7474
author="SuperTokens",
7575
license="Apache 2.0",
7676
author_email="[email protected]",

supertokens_python/constants.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
1414
SUPPORTED_CDI_VERSIONS = ["2.9", "2.10", "2.11", "2.12", "2.13", "2.14", "2.15"]
15-
VERSION = "0.11.13"
15+
VERSION = "0.11.14"
1616
TELEMETRY = "/telemetry"
1717
USER_COUNT = "/users/count"
1818
USER_DELETE = "/user/remove"
@@ -26,3 +26,4 @@
2626
API_VERSION = "/apiversion"
2727
API_VERSION_HEADER = "cdi-version"
2828
DASHBOARD_VERSION = "0.3"
29+
RATE_LIMIT_STATUS_CODE = 429

supertokens_python/querier.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
# under the License.
1414
from __future__ import annotations
1515

16+
import asyncio
17+
1618
from json import JSONDecodeError
1719
from os import environ
18-
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict
20+
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional
1921

2022
from httpx import AsyncClient, ConnectTimeout, NetworkError, Response
2123

@@ -25,6 +27,7 @@
2527
API_VERSION_HEADER,
2628
RID_KEY_HEADER,
2729
SUPPORTED_CDI_VERSIONS,
30+
RATE_LIMIT_STATUS_CODE,
2831
)
2932
from .normalised_url_path import NormalisedURLPath
3033

@@ -42,7 +45,7 @@ class Querier:
4245
__init_called = False
4346
__hosts: List[Host] = []
4447
__api_key: Union[None, str] = None
45-
__api_version = None
48+
api_version = None
4649
__last_tried_index: int = 0
4750
__hosts_alive_for_testing: Set[str] = set()
4851

@@ -69,8 +72,8 @@ def get_hosts_alive_for_testing():
6972
return Querier.__hosts_alive_for_testing
7073

7174
async def get_api_version(self):
72-
if Querier.__api_version is not None:
73-
return Querier.__api_version
75+
if Querier.api_version is not None:
76+
return Querier.api_version
7477

7578
ProcessState.get_instance().add_state(
7679
AllowedProcessStates.CALLING_SERVICE_IN_GET_API_VERSION
@@ -96,8 +99,8 @@ async def f(url: str) -> Response:
9699
"to find the right versions"
97100
)
98101

99-
Querier.__api_version = api_version
100-
return Querier.__api_version
102+
Querier.api_version = api_version
103+
return Querier.api_version
101104

102105
@staticmethod
103106
def get_instance(rid_to_core: Union[str, None] = None):
@@ -113,7 +116,7 @@ def init(hosts: List[Host], api_key: Union[str, None] = None):
113116
Querier.__init_called = True
114117
Querier.__hosts = hosts
115118
Querier.__api_key = api_key
116-
Querier.__api_version = None
119+
Querier.api_version = None
117120
Querier.__last_tried_index = 0
118121
Querier.__hosts_alive_for_testing = set()
119122

@@ -196,6 +199,7 @@ async def __send_request_helper(
196199
method: str,
197200
http_function: Callable[[str], Awaitable[Response]],
198201
no_of_tries: int,
202+
retry_info_map: Optional[Dict[str, int]] = None,
199203
) -> Any:
200204
if no_of_tries == 0:
201205
raise_general_exception("No SuperTokens core available to query")
@@ -212,6 +216,14 @@ async def __send_request_helper(
212216
Querier.__last_tried_index %= len(self.__hosts)
213217
url = current_host + path.get_as_string_dangerous()
214218

219+
max_retries = 5
220+
221+
if retry_info_map is None:
222+
retry_info_map = {}
223+
224+
if retry_info_map.get(url) is None:
225+
retry_info_map[url] = max_retries
226+
215227
ProcessState.get_instance().add_state(
216228
AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER
217229
)
@@ -221,6 +233,20 @@ async def __send_request_helper(
221233
):
222234
Querier.__hosts_alive_for_testing.add(current_host)
223235

236+
if response.status_code == RATE_LIMIT_STATUS_CODE:
237+
retries_left = retry_info_map[url]
238+
239+
if retries_left > 0:
240+
retry_info_map[url] = retries_left - 1
241+
242+
attempts_made = max_retries - retries_left
243+
delay = (10 + attempts_made * 250) / 1000
244+
245+
await asyncio.sleep(delay)
246+
return await self.__send_request_helper(
247+
path, method, http_function, no_of_tries, retry_info_map
248+
)
249+
224250
if is_4xx_error(response.status_code) or is_5xx_error(response.status_code): # type: ignore
225251
raise_general_exception(
226252
"SuperTokens core threw an error for a "
@@ -238,9 +264,9 @@ async def __send_request_helper(
238264
except JSONDecodeError:
239265
return response.text
240266

241-
except (ConnectionError, NetworkError, ConnectTimeout):
267+
except (ConnectionError, NetworkError, ConnectTimeout) as _:
242268
return await self.__send_request_helper(
243-
path, method, http_function, no_of_tries - 1
269+
path, method, http_function, no_of_tries - 1, retry_info_map
244270
)
245271
except Exception as e:
246272
raise_general_exception(e)

tests/test_querier.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
2+
#
3+
# This software is licensed under the Apache License, Version 2.0 (the
4+
# "License") as published by the Apache Software Foundation.
5+
#
6+
# You may not use this file except in compliance with the License. You may
7+
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations
13+
# under the License.
14+
from pytest import mark
15+
from supertokens_python.recipe import (
16+
session,
17+
emailpassword,
18+
)
19+
import asyncio
20+
import respx
21+
import httpx
22+
from supertokens_python import init, SupertokensConfig
23+
from supertokens_python.querier import Querier, NormalisedURLPath
24+
25+
from tests.utils import get_st_init_args
26+
from tests.utils import (
27+
setup_function,
28+
teardown_function,
29+
start_st,
30+
)
31+
32+
_ = setup_function
33+
_ = teardown_function
34+
35+
pytestmark = mark.asyncio
36+
respx_mock = respx.MockRouter
37+
38+
39+
async def test_network_call_is_retried_as_expected():
40+
# Test that network call is retried properly
41+
# Test that rate limiting errors are thrown back to the user
42+
args = get_st_init_args(
43+
[
44+
session.init(),
45+
emailpassword.init(),
46+
]
47+
)
48+
args["supertokens_config"] = SupertokensConfig("http://localhost:6789")
49+
init(**args) # type: ignore
50+
start_st()
51+
52+
Querier.api_version = "3.0"
53+
q = Querier.get_instance()
54+
55+
api2_call_count = 0
56+
57+
def api2_side_effect(_: httpx.Request):
58+
nonlocal api2_call_count
59+
api2_call_count += 1
60+
61+
if api2_call_count == 3:
62+
return httpx.Response(200)
63+
64+
return httpx.Response(429, json={})
65+
66+
with respx_mock() as mocker:
67+
api1 = mocker.get("http://localhost:6789/api1").mock(
68+
httpx.Response(429, json={"status": "RATE_ERROR"})
69+
)
70+
api2 = mocker.get("http://localhost:6789/api2").mock(
71+
side_effect=api2_side_effect
72+
)
73+
api3 = mocker.get("http://localhost:6789/api3").mock(httpx.Response(200))
74+
75+
try:
76+
await q.send_get_request(NormalisedURLPath("/api1"), {})
77+
except Exception as e:
78+
if "with status code: 429" in str(
79+
e
80+
) and 'message: {"status": "RATE_ERROR"}' in str(e):
81+
pass
82+
else:
83+
raise e
84+
85+
await q.send_get_request(NormalisedURLPath("/api2"), {})
86+
await q.send_get_request(NormalisedURLPath("/api3"), {})
87+
88+
# 1 initial request + 5 retries
89+
assert api1.call_count == 6
90+
# 2 403 and 1 200
91+
assert api2.call_count == 3
92+
# 200 in the first attempt
93+
assert api3.call_count == 1
94+
95+
96+
async def test_parallel_calls_have_independent_counters():
97+
args = get_st_init_args(
98+
[
99+
session.init(),
100+
emailpassword.init(),
101+
]
102+
)
103+
init(**args) # type: ignore
104+
start_st()
105+
106+
Querier.api_version = "3.0"
107+
q = Querier.get_instance()
108+
109+
call_count1 = 0
110+
call_count2 = 0
111+
112+
def api_side_effect(r: httpx.Request):
113+
nonlocal call_count1, call_count2
114+
115+
id_ = int(r.url.params.get("id"))
116+
if id_ == 1:
117+
call_count1 += 1
118+
elif id_ == 2:
119+
call_count2 += 1
120+
121+
return httpx.Response(429, json={})
122+
123+
with respx_mock() as mocker:
124+
api = mocker.get("http://localhost:3567/api").mock(side_effect=api_side_effect)
125+
126+
async def call_api(id_: int):
127+
try:
128+
await q.send_get_request(NormalisedURLPath("/api"), {"id": id_})
129+
except Exception as e:
130+
if "with status code: 429" in str(e):
131+
pass
132+
else:
133+
raise e
134+
135+
_ = await asyncio.gather(
136+
call_api(1),
137+
call_api(2),
138+
)
139+
140+
# 1 initial request + 5 retries
141+
assert call_count1 == 6
142+
assert call_count2 == 6
143+
144+
assert api.call_count == 12

tests/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,13 +365,13 @@ def email_verify_token_request(
365365
environ["SUPERTOKENS_ENV"] = "testing"
366366

367367

368-
def setup_function(_):
368+
def setup_function(_: Any):
369369
reset()
370370
clean_st()
371371
setup_st()
372372

373373

374-
def teardown_function(_):
374+
def teardown_function(_: Any):
375375
reset()
376376
clean_st()
377377

0 commit comments

Comments
 (0)