Skip to content

Commit 365defa

Browse files
committed
feat: allow disabling tldextract HTTP requests
- Workaround till tldextract#233 is implemented - Adds flag to control disabling HTTP requests - Adds `pyfakefs` to use in tests
1 parent 68b9356 commit 365defa

File tree

12 files changed

+99
-12
lines changed

12 files changed

+99
-12
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ __pycache__
88
releasePassword
99
apiPassword
1010
venv/
11-
env/
11+
./env/
1212
.env
1313
.DS_Store
1414
bin/

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88

99
## [unreleased]
1010

11+
## [0.28.1] - 2025-02-17
12+
- Adds option to disable `tldextract` HTTP calls by setting `SUPERTOKENS_TLDEXTRACT_DISABLE_HTTP=1`
13+
1114
## [0.28.0]
1215
- **[Breaking] Updates pre-commit hooks to use `pre-commit`**
1316
- Migration:

dev-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ flask-cors==5.0.0
99
nest-asyncio==1.6.0
1010
pdoc3==0.11.0
1111
pre-commit==3.5.0
12+
pyfakefs==5.7.4
1213
pylint==3.2.7
1314
pyright==1.1.393
1415
python-dotenv==1.0.1

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@
8383

8484
setup(
8585
name="supertokens_python",
86-
version="0.28.0",
86+
version="0.28.1",
8787
author="SuperTokens",
8888
license="Apache 2.0",
8989
author_email="[email protected]",

supertokens_python/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
SUPPORTED_CDI_VERSIONS = ["5.2"]
18-
VERSION = "0.28.0"
18+
VERSION = "0.28.1"
1919
TELEMETRY = "/telemetry"
2020
USER_COUNT = "/users/count"
2121
USER_DELETE = "/user/remove"

supertokens_python/env/__init__.py

Whitespace-only changes.

supertokens_python/env/base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from os import environ
2+
3+
from supertokens_python.env.utils import str_to_bool
4+
5+
6+
def FLAG_tldextract_disable_http():
7+
"""
8+
Disable HTTP calls from `tldextract`.
9+
"""
10+
val = environ.get("SUPERTOKENS_TLDEXTRACT_DISABLE_HTTP", "0")
11+
12+
return str_to_bool(val)

supertokens_python/env/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
def str_to_bool(val: str) -> bool:
2+
"""
3+
Convert ENV values to boolean
4+
"""
5+
return val.lower() in ("true", "t", "1")

supertokens_python/utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@
3535
from urllib.parse import urlparse
3636

3737
from httpx import HTTPStatusError, Response
38-
from tldextract import extract # type: ignore
38+
from tldextract import TLDExtract
3939

40+
from supertokens_python.env.base import FLAG_tldextract_disable_http
4041
from supertokens_python.framework.django.framework import DjangoFramework
4142
from supertokens_python.framework.fastapi.framework import FastapiFramework
4243
from supertokens_python.framework.flask.framework import FlaskFramework
@@ -288,7 +289,16 @@ def get_top_level_domain_for_same_site_resolution(url: str) -> str:
288289
if hostname.startswith("localhost") or is_an_ip_address(hostname):
289290
return "localhost"
290291

291-
parsed_url: Any = extract(hostname, include_psl_private_domains=True)
292+
extract = TLDExtract(fallback_to_snapshot=True, include_psl_private_domains=True)
293+
# Explicitly disable HTTP calls, use snapshot bundled into library
294+
if FLAG_tldextract_disable_http():
295+
extract = TLDExtract(
296+
suffix_list_urls=(), # Ensures no HTTP calls
297+
fallback_to_snapshot=True,
298+
include_psl_private_domains=True,
299+
)
300+
301+
parsed_url: Any = extract(hostname)
292302
if parsed_url.domain == "": # type: ignore
293303
# We need to do this because of https://github.com/supertokens/supertokens-python/issues/394
294304
if hostname.endswith(".amazonaws.com") and parsed_url.suffix == hostname:

tests/auth-react/fastapi-server/app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1117,7 +1117,7 @@ async def exception_handler(a, b): # type: ignore
11171117
return JSONResponse(status_code=500, content={})
11181118

11191119

1120-
app.add_middleware(ExceptionMiddleware, handlers=app.exception_handlers)
1120+
app.add_middleware(ExceptionMiddleware, handlers=app.exception_handlers) # type: ignore
11211121

11221122

11231123
@app.post("/beforeeach")

tests/test_utils.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
1+
import os
12
import threading
3+
from contextlib import ExitStack
24
from typing import Any, Dict, List, Union
5+
from unittest.mock import patch
36

4-
import pytest
7+
from pytest import mark, param, raises
58
from supertokens_python.utils import (
69
RWMutex,
710
get_top_level_domain_for_same_site_resolution,
811
humanize_time,
912
is_version_gte,
1013
)
1114

12-
from tests.utils import is_subset
15+
from tests.utils import is_subset, outputs
1316

1417

15-
@pytest.mark.parametrize(
18+
@mark.parametrize(
1619
"version,min_minor_version,is_gte",
1720
[
1821
(
@@ -72,7 +75,7 @@ def test_util_is_version_gte(version: str, min_minor_version: str, is_gte: bool)
7275
HOUR = 60 * MINUTE
7376

7477

75-
@pytest.mark.parametrize(
78+
@mark.parametrize(
7679
"ms,out",
7780
[
7881
(1 * SECOND, "1 second"),
@@ -91,7 +94,7 @@ def test_humanize_time(ms: int, out: str):
9194
assert humanize_time(ms) == out
9295

9396

94-
@pytest.mark.parametrize(
97+
@mark.parametrize(
9598
"d1,d2,result",
9699
[
97100
({"a": {"b": [1, 2]}, "c": 1}, {"c": 1}, True),
@@ -176,7 +179,7 @@ def balance_is_valid():
176179
assert actual_balance == expected_balance, "Incorrect account balance"
177180

178181

179-
@pytest.mark.parametrize(
182+
@mark.parametrize(
180183
"url,res",
181184
[
182185
("http://localhost:3001", "localhost"),
@@ -196,3 +199,41 @@ def balance_is_valid():
196199
)
197200
def test_tld_for_same_site(url: str, res: str):
198201
assert get_top_level_domain_for_same_site_resolution(url) == res
202+
203+
204+
@mark.parametrize(
205+
["internet_disabled", "env_val", "expectation"],
206+
[
207+
param(True, "False", raises(RuntimeError), id="Internet disabled, flag unset"),
208+
param(True, "True", outputs("google.com"), id="Internet disabled, flag set"),
209+
param(False, "False", outputs("google.com"), id="Internet enabled, flag unset"),
210+
param(False, "True", outputs("google.com"), id="Internet enabled, flag set"),
211+
],
212+
)
213+
def test_tldextract_http_toggle(
214+
internet_disabled: bool,
215+
env_val: str,
216+
expectation: Any,
217+
# pyfakefs fixture, mocks the filesystem
218+
# Mocking `tldextract`'s cache path does not work in repeated tests
219+
fs: Any,
220+
):
221+
import socket
222+
223+
# Disable sockets, will raise errors on HTTP calls
224+
socket_patch = patch.object(socket, "socket", side_effect=RuntimeError)
225+
environ_patch = patch.dict(
226+
os.environ,
227+
{"SUPERTOKENS_TLDEXTRACT_DISABLE_HTTP": env_val},
228+
)
229+
230+
stack = ExitStack()
231+
stack.enter_context(environ_patch)
232+
if internet_disabled:
233+
stack.enter_context(socket_patch)
234+
235+
# if `expectation` is raises, checks for raise
236+
# if `outputs`, value used in `assert` statement
237+
with stack, expectation as expected_output:
238+
output = get_top_level_domain_for_same_site_resolution("https://google.com")
239+
assert output == expected_output

tests/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
# Import AsyncMock
1818
import sys
19+
from contextlib import contextmanager
1920
from datetime import datetime
2021
from http.cookies import SimpleCookie
2122
from os import environ, kill, remove, scandir
@@ -613,3 +614,17 @@ async def create_users(
613614
await manually_create_or_update_user(
614615
"public", user["provider"], user["userId"], user["email"], True, None
615616
)
617+
618+
619+
@contextmanager
620+
def outputs(val: Any):
621+
"""
622+
Outputs a value to assert.
623+
624+
Usage:
625+
@mark.parametrize(["input", "expectation"], [(1, outputs(1)), (0, raises(Exception))])
626+
def test_fn(input, expectation):
627+
with expectation as expected_output:
628+
assert 1 / input == expected_output
629+
"""
630+
yield val

0 commit comments

Comments
 (0)