Skip to content

adds caching for querier #508

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [unreleased]

## [0.22.0] - 2024-06-05
- Adds caching per API based on user context.

### Breaking change:
- Changes general error in querier to normal python error.

## [0.21.0] - 2024-05-23

### Breaking change
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@

setup(
name="supertokens_python",
version="0.21.0",
version="0.22.0",
author="SuperTokens",
license="Apache 2.0",
author_email="[email protected]",
Expand Down
2 changes: 1 addition & 1 deletion supertokens_python/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import annotations

SUPPORTED_CDI_VERSIONS = ["3.0"]
VERSION = "0.21.0"
VERSION = "0.22.0"
TELEMETRY = "/telemetry"
USER_COUNT = "/users/count"
USER_DELETE = "/user/remove"
Expand Down
105 changes: 91 additions & 14 deletions supertokens_python/querier.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@

from typing import List, Set, Union

from .exceptions import raise_general_exception
from .process_state import AllowedProcessStates, ProcessState
from .utils import find_max_version, is_4xx_error, is_5xx_error
from sniffio import AsyncLibraryNotFoundError
from supertokens_python.async_to_sync_wrapper import create_or_get_event_loop
from supertokens_python.utils import get_timestamp_ms


class Querier:
Expand Down Expand Up @@ -68,10 +68,13 @@ class Querier:
],
]
] = None
__global_cache_tag = get_timestamp_ms()
__disable_cache = False

def __init__(self, hosts: List[Host], rid_to_core: Union[None, str] = None):
self.__hosts = hosts
self.__rid_to_core = None
self.__global_cache_tag = get_timestamp_ms()
if rid_to_core is not None:
self.__rid_to_core = rid_to_core

Expand All @@ -80,15 +83,15 @@ def reset():
if ("SUPERTOKENS_ENV" not in environ) or (
environ["SUPERTOKENS_ENV"] != "testing"
):
raise_general_exception("calling testing function in non testing env")
raise Exception("calling testing function in non testing env")
Querier.__init_called = False

@staticmethod
def get_hosts_alive_for_testing():
if ("SUPERTOKENS_ENV" not in environ) or (
environ["SUPERTOKENS_ENV"] != "testing"
):
raise_general_exception("calling testing function in non testing env")
raise Exception("calling testing function in non testing env")
return Querier.__hosts_alive_for_testing

async def api_request(
Expand All @@ -100,7 +103,7 @@ async def api_request(
**kwargs: Any,
) -> Response:
if attempts_remaining == 0:
raise_general_exception("Retry request failed")
raise Exception("Retry request failed")

try:
async with AsyncClient() as client:
Expand Down Expand Up @@ -141,7 +144,7 @@ async def f(url: str, method: str) -> Response:
api_version = find_max_version(cdi_supported_by_server, SUPPORTED_CDI_VERSIONS)

if api_version is None:
raise_general_exception(
raise Exception(
"The running SuperTokens core version is not compatible with this python "
"SDK. Please visit https://supertokens.io/docs/community/compatibility-table "
"to find the right versions"
Expand All @@ -152,7 +155,7 @@ async def f(url: str, method: str) -> Response:

@staticmethod
def get_instance(rid_to_core: Union[str, None] = None):
if (not Querier.__init_called) or (Querier.__hosts is None):
if not Querier.__init_called:
raise Exception(
"Please call the supertokens.init function before using SuperTokens"
)
Expand Down Expand Up @@ -181,6 +184,7 @@ def init(
],
]
] = None,
disable_cache: bool = False,
):
if not Querier.__init_called:
Querier.__init_called = True
Expand All @@ -190,6 +194,7 @@ def init(
Querier.__last_tried_index = 0
Querier.__hosts_alive_for_testing = set()
Querier.network_interceptor = network_interceptor
Querier.__disable_cache = disable_cache

async def __get_headers_with_api_version(self, path: NormalisedURLPath):
headers = {API_VERSION_HEADER: await self.get_api_version()}
Expand All @@ -211,6 +216,41 @@ async def send_get_request(
async def f(url: str, method: str) -> Response:
headers = await self.__get_headers_with_api_version(path)
nonlocal params

assert params is not None

# Sort the keys for deterministic order
sorted_keys = sorted(params.keys())
sorted_header_keys = sorted(headers.keys())

# Start with the path as the unique key
unique_key = path.get_as_string_dangerous()

# Append sorted params to the unique key
for key in sorted_keys:
value = params[key]
unique_key += f";{key}={value}"

# Append a separator for headers
unique_key += ";hdrs"

# Append sorted headers to the unique key
for key in sorted_header_keys:
value = headers[key]
unique_key += f";{key}={value}"

if user_context is not None:
if (
user_context.get("_default", {}).get("global_cache_tag", -1)
!= self.__global_cache_tag
):
self.invalidate_core_call_cache(user_context, False)

if not Querier.__disable_cache and unique_key in user_context.get(
"_default", {}
).get("core_call_cache", {}):
return user_context["_default"]["core_call_cache"][unique_key]

if Querier.network_interceptor is not None:
(
url,
Expand All @@ -222,14 +262,30 @@ async def f(url: str, method: str) -> Response:
url, method, headers, params, {}, user_context
)

return await self.api_request(
response = await self.api_request(
url,
method,
2,
headers=headers,
params=params,
)

if (
response.status_code == 200
and not Querier.__disable_cache
and user_context is not None
):
user_context["_default"] = {
**user_context.get("_default", {}),
"core_call_cache": {
**user_context.get("_default", {}).get("core_call_cache", {}),
unique_key: response,
},
"global_cache_tag": self.__global_cache_tag,
}

return response

return await self.__send_request_helper(path, "GET", f, len(self.__hosts))

async def send_post_request(
Expand All @@ -239,6 +295,7 @@ async def send_post_request(
user_context: Union[Dict[str, Any], None],
test: bool = False,
) -> Dict[str, Any]:
self.invalidate_core_call_cache(user_context)
if data is None:
data = {}

Expand Down Expand Up @@ -280,6 +337,8 @@ async def send_delete_request(
params: Union[Dict[str, Any], None],
user_context: Union[Dict[str, Any], None],
) -> Dict[str, Any]:
if user_context is not None:
self.invalidate_core_call_cache(user_context)
if params is None:
params = {}

Expand Down Expand Up @@ -312,6 +371,7 @@ async def send_put_request(
data: Union[Dict[str, Any], None],
user_context: Union[Dict[str, Any], None],
) -> Dict[str, Any]:
self.invalidate_core_call_cache(user_context)
if data is None:
data = {}

Expand All @@ -334,10 +394,29 @@ async def f(url: str, method: str) -> Response:

return await self.__send_request_helper(path, "PUT", f, len(self.__hosts))

def get_all_core_urls_for_path(self, path: str) -> List[str]:
if self.__hosts is None:
return []
def invalidate_core_call_cache(
self,
user_context: Union[Dict[str, Any], None],
upd_global_cache_tag_if_necessary: bool = True,
):
if user_context is None:
# this is done so that the code below runs as expected.
# It will reset the __global_cache_tag if needed, and the
# stuff we assign to the user_context will just be ignored (as expected)
user_context = {}

if upd_global_cache_tag_if_necessary and (
user_context.get("_default", {}).get("keep_cache_alive", False) is not True
):
# there can be race conditions here, but i think we can ignore them.
self.__global_cache_tag = get_timestamp_ms()

user_context["_default"] = {
**user_context.get("_default", {}),
"core_call_cache": {},
}

def get_all_core_urls_for_path(self, path: str) -> List[str]:
normalized_path = NormalisedURLPath(path)

result: List[str] = []
Expand All @@ -362,7 +441,7 @@ async def __send_request_helper(
retry_info_map: Optional[Dict[str, int]] = None,
) -> Dict[str, Any]:
if no_of_tries == 0:
raise_general_exception("No SuperTokens core available to query")
raise Exception("No SuperTokens core available to query")

try:
current_host_domain = self.__hosts[
Expand Down Expand Up @@ -408,7 +487,7 @@ async def __send_request_helper(
)

if is_4xx_error(response.status_code) or is_5xx_error(response.status_code): # type: ignore
raise_general_exception(
raise Exception(
"SuperTokens core threw an error for a "
+ method
+ " request to path: "
Expand All @@ -432,5 +511,3 @@ async def __send_request_helper(
return await self.__send_request_helper(
path, method, http_function, no_of_tries - 1, retry_info_map
)
except Exception as e:
raise_general_exception(e)
7 changes: 6 additions & 1 deletion supertokens_python/supertokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,12 @@ def __init__(
],
]
] = None,
disable_core_call_cache: bool = False,
): # We keep this = None here because this is directly used by the user.
self.connection_uri = connection_uri
self.api_key = api_key
self.network_interceptor = network_interceptor
self.disable_core_call_cache = disable_core_call_cache


class Host:
Expand Down Expand Up @@ -243,7 +245,10 @@ def __init__(
)
)
Querier.init(
hosts, supertokens_config.api_key, supertokens_config.network_interceptor
hosts,
supertokens_config.api_key,
supertokens_config.network_interceptor,
supertokens_config.disable_core_call_cache,
)

if len(recipe_list) == 0:
Expand Down
1 change: 1 addition & 0 deletions supertokens_python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def set_request_in_user_context_if_not_defined(

if isinstance(user_context["_default"], dict):
user_context["_default"]["request"] = request
user_context["_default"]["keep_cache_alive"] = True

return user_context

Expand Down
Loading
Loading