Skip to content

Commit 7c74be5

Browse files
Merge pull request #440 from supertokens/fix/async-lib-not-found-err
fix: Async lib not found error
2 parents dbca4da + d76e8fb commit 7c74be5

File tree

4 files changed

+72
-35
lines changed

4 files changed

+72
-35
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88

99
## [unreleased]
1010

11+
- Retry Querier request on `AsyncLibraryNotFoundError`
12+
1113
## [0.14.10] - 2023-09-31
1214

13-
- Uses nest_asyncio patch in event loop - sync to async
15+
- Uses `nest_asyncio` patch in event loop - sync to async
1416

1517
## [0.14.9] - 2023-09-28
1618

supertokens_python/async_to_sync_wrapper.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,18 @@
1919
_T = TypeVar("_T")
2020

2121

22-
def check_event_loop():
22+
def create_or_get_event_loop() -> asyncio.AbstractEventLoop:
2323
try:
24-
asyncio.get_event_loop()
25-
except RuntimeError as ex:
24+
return asyncio.get_event_loop()
25+
except Exception as ex:
2626
if "There is no current event loop in thread" in str(ex):
2727
loop = asyncio.new_event_loop()
2828
nest_asyncio.apply(loop) # type: ignore
2929
asyncio.set_event_loop(loop)
30+
return loop
31+
raise ex
3032

3133

3234
def sync(co: Coroutine[Any, Any, _T]) -> _T:
33-
check_event_loop()
34-
loop = asyncio.get_event_loop()
35+
loop = create_or_get_event_loop()
3536
return loop.run_until_complete(co)

supertokens_python/querier.py

Lines changed: 61 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
from .exceptions import raise_general_exception
4040
from .process_state import AllowedProcessStates, ProcessState
4141
from .utils import find_max_version, is_4xx_error, is_5xx_error
42+
from supertokens_python.async_to_sync_wrapper import create_or_get_event_loop
43+
from sniffio import AsyncLibraryNotFoundError
4244

4345

4446
class Querier:
@@ -71,6 +73,35 @@ def get_hosts_alive_for_testing():
7173
raise_general_exception("calling testing function in non testing env")
7274
return Querier.__hosts_alive_for_testing
7375

76+
async def api_request(
77+
self,
78+
url: str,
79+
method: str,
80+
attempts_remaining: int,
81+
*args: Any,
82+
**kwargs: Any,
83+
) -> Response:
84+
if attempts_remaining == 0:
85+
raise_general_exception("Retry request failed")
86+
87+
try:
88+
async with AsyncClient() as client:
89+
if method == "GET":
90+
return await client.get(url, *args, **kwargs) # type: ignore
91+
if method == "POST":
92+
return await client.post(url, *args, **kwargs) # type: ignore
93+
if method == "PUT":
94+
return await client.put(url, *args, **kwargs) # type: ignore
95+
if method == "DELETE":
96+
return await client.delete(url, *args, **kwargs) # type: ignore
97+
raise Exception("Shouldn't come here")
98+
except AsyncLibraryNotFoundError:
99+
# Retry
100+
loop = create_or_get_event_loop()
101+
return loop.run_until_complete(
102+
self.api_request(url, method, attempts_remaining - 1, *args, **kwargs)
103+
)
104+
74105
async def get_api_version(self):
75106
if Querier.api_version is not None:
76107
return Querier.api_version
@@ -79,12 +110,11 @@ async def get_api_version(self):
79110
AllowedProcessStates.CALLING_SERVICE_IN_GET_API_VERSION
80111
)
81112

82-
async def f(url: str) -> Response:
113+
async def f(url: str, method: str) -> Response:
83114
headers = {}
84115
if Querier.__api_key is not None:
85116
headers = {API_KEY_HEADER: Querier.__api_key}
86-
async with AsyncClient() as client:
87-
return await client.get(url, headers=headers) # type:ignore
117+
return await self.api_request(url, method, 2, headers=headers)
88118

89119
response = await self.__send_request_helper(
90120
NormalisedURLPath(API_VERSION), "GET", f, len(self.__hosts)
@@ -134,13 +164,14 @@ async def send_get_request(
134164
if params is None:
135165
params = {}
136166

137-
async def f(url: str) -> Response:
138-
async with AsyncClient() as client:
139-
return await client.get( # type:ignore
140-
url,
141-
params=params,
142-
headers=await self.__get_headers_with_api_version(path),
143-
)
167+
async def f(url: str, method: str) -> Response:
168+
return await self.api_request(
169+
url,
170+
method,
171+
2,
172+
headers=await self.__get_headers_with_api_version(path),
173+
params=params,
174+
)
144175

145176
return await self.__send_request_helper(path, "GET", f, len(self.__hosts))
146177

@@ -163,9 +194,14 @@ async def send_post_request(
163194
headers = await self.__get_headers_with_api_version(path)
164195
headers["content-type"] = "application/json; charset=utf-8"
165196

166-
async def f(url: str) -> Response:
167-
async with AsyncClient() as client:
168-
return await client.post(url, json=data, headers=headers) # type: ignore
197+
async def f(url: str, method: str) -> Response:
198+
return await self.api_request(
199+
url,
200+
method,
201+
2,
202+
headers=await self.__get_headers_with_api_version(path),
203+
json=data,
204+
)
169205

170206
return await self.__send_request_helper(path, "POST", f, len(self.__hosts))
171207

@@ -175,13 +211,14 @@ async def send_delete_request(
175211
if params is None:
176212
params = {}
177213

178-
async def f(url: str) -> Response:
179-
async with AsyncClient() as client:
180-
return await client.delete( # type:ignore
181-
url,
182-
params=params,
183-
headers=await self.__get_headers_with_api_version(path),
184-
)
214+
async def f(url: str, method: str) -> Response:
215+
return await self.api_request(
216+
url,
217+
method,
218+
2,
219+
headers=await self.__get_headers_with_api_version(path),
220+
params=params,
221+
)
185222

186223
return await self.__send_request_helper(path, "DELETE", f, len(self.__hosts))
187224

@@ -194,9 +231,8 @@ async def send_put_request(
194231
headers = await self.__get_headers_with_api_version(path)
195232
headers["content-type"] = "application/json; charset=utf-8"
196233

197-
async def f(url: str) -> Response:
198-
async with AsyncClient() as client:
199-
return await client.put(url, json=data, headers=headers) # type: ignore
234+
async def f(url: str, method: str) -> Response:
235+
return await self.api_request(url, method, 2, headers=headers, json=data)
200236

201237
return await self.__send_request_helper(path, "PUT", f, len(self.__hosts))
202238

@@ -223,7 +259,7 @@ async def __send_request_helper(
223259
self,
224260
path: NormalisedURLPath,
225261
method: str,
226-
http_function: Callable[[str], Awaitable[Response]],
262+
http_function: Callable[[str, str], Awaitable[Response]],
227263
no_of_tries: int,
228264
retry_info_map: Optional[Dict[str, int]] = None,
229265
) -> Any:
@@ -253,7 +289,7 @@ async def __send_request_helper(
253289
ProcessState.get_instance().add_state(
254290
AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER
255291
)
256-
response = await http_function(url)
292+
response = await http_function(url, method)
257293
if ("SUPERTOKENS_ENV" in environ) and (
258294
environ["SUPERTOKENS_ENV"] == "testing"
259295
):
@@ -289,7 +325,6 @@ async def __send_request_helper(
289325
return response.json()
290326
except JSONDecodeError:
291327
return response.text
292-
293328
except (ConnectionError, NetworkError, ConnectTimeout) as _:
294329
return await self.__send_request_helper(
295330
path, method, http_function, no_of_tries - 1, retry_info_map

supertokens_python/utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from httpx import HTTPStatusError, Response
4040
from tldextract import extract # type: ignore
4141

42-
from supertokens_python.async_to_sync_wrapper import check_event_loop
42+
from supertokens_python.async_to_sync_wrapper import create_or_get_event_loop
4343
from supertokens_python.framework.django.framework import DjangoFramework
4444
from supertokens_python.framework.fastapi.framework import FastapiFramework
4545
from supertokens_python.framework.flask.framework import FlaskFramework
@@ -212,8 +212,7 @@ def execute_async(mode: str, func: Callable[[], Coroutine[Any, Any, None]]):
212212
if real_mode == "wsgi":
213213
asyncio.run(func())
214214
else:
215-
check_event_loop()
216-
loop = asyncio.get_event_loop()
215+
loop = create_or_get_event_loop()
217216
loop.create_task(func())
218217

219218

0 commit comments

Comments
 (0)