Skip to content

feat: allow disabling tldextract HTTP requests #563

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ __pycache__
releasePassword
apiPassword
venv/
env/
./env/
.env
.DS_Store
bin/
Expand Down
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0


## [unreleased]
- Upgrades `pip` and `setuptools` in CI publish job
- Also upgrades `poetry` and it's dependency - `clikit`

## [0.29.0] - 2025-03-03
- Adds option to disable `tldextract` HTTP calls by setting `SUPERTOKENS_TLDEXTRACT_DISABLE_HTTP=1`
- Upgrades `pip` and `setuptools` in CI publish job
- Also upgrades `poetry` and it's dependency - `clikit`
- Migrates unit tests to use a containerized core
- Updates `Makefile` to use a Docker `compose` setup step
- Migrates unit tests from CircleCI to Github Actions
Expand Down
1 change: 1 addition & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ flask-cors==5.0.0
nest-asyncio==1.6.0
pdoc3==0.11.0
pre-commit==3.5.0
pyfakefs==5.7.4
pylint==3.2.7
pyright==1.1.393
python-dotenv==1.0.1
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
"PyJWT[crypto]>=2.5.0,<3.0.0",
"httpx>=0.15.0,<1.0.0",
"pycryptodome<3.21.0",
"tldextract<5.1.3",
"tldextract<6.0.0",
"asgiref>=3.4.1,<4",
"typing_extensions>=4.1.1,<5.0.0",
"Deprecated<1.3.0",
Expand Down
Empty file.
12 changes: 12 additions & 0 deletions supertokens_python/env/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from os import environ

from supertokens_python.env.utils import str_to_bool


def FLAG_tldextract_disable_http():
"""
Disable HTTP calls from `tldextract`.
"""
val = environ.get("SUPERTOKENS_TLDEXTRACT_DISABLE_HTTP", "0")

return str_to_bool(val)
5 changes: 5 additions & 0 deletions supertokens_python/env/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def str_to_bool(val: str) -> bool:
"""
Convert ENV values to boolean
"""
return val.lower() in ("true", "t", "1")
14 changes: 12 additions & 2 deletions supertokens_python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@
from urllib.parse import urlparse

from httpx import HTTPStatusError, Response
from tldextract import extract # type: ignore
from tldextract import TLDExtract

from supertokens_python.env.base import FLAG_tldextract_disable_http
from supertokens_python.framework.django.framework import DjangoFramework
from supertokens_python.framework.fastapi.framework import FastapiFramework
from supertokens_python.framework.flask.framework import FlaskFramework
Expand Down Expand Up @@ -288,7 +289,16 @@ def get_top_level_domain_for_same_site_resolution(url: str) -> str:
if hostname.startswith("localhost") or is_an_ip_address(hostname):
return "localhost"

parsed_url: Any = extract(hostname, include_psl_private_domains=True)
extract = TLDExtract(fallback_to_snapshot=True, include_psl_private_domains=True)
# Explicitly disable HTTP calls, use snapshot bundled into library
if FLAG_tldextract_disable_http():
extract = TLDExtract(
suffix_list_urls=(), # Ensures no HTTP calls
fallback_to_snapshot=True,
include_psl_private_domains=True,
)

parsed_url: Any = extract(hostname)
if parsed_url.domain == "": # type: ignore
# We need to do this because of https://github.com/supertokens/supertokens-python/issues/394
if hostname.endswith(".amazonaws.com") and parsed_url.suffix == hostname:
Expand Down
53 changes: 47 additions & 6 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import os
import threading
from contextlib import ExitStack
from typing import Any, Dict, List, Union
from unittest.mock import patch

import pytest
from pytest import mark, param, raises
from supertokens_python.utils import (
RWMutex,
get_top_level_domain_for_same_site_resolution,
humanize_time,
is_version_gte,
)

from tests.utils import is_subset
from tests.utils import is_subset, outputs


@pytest.mark.parametrize(
@mark.parametrize(
"version,min_minor_version,is_gte",
[
(
Expand Down Expand Up @@ -72,7 +75,7 @@ def test_util_is_version_gte(version: str, min_minor_version: str, is_gte: bool)
HOUR = 60 * MINUTE


@pytest.mark.parametrize(
@mark.parametrize(
"ms,out",
[
(1 * SECOND, "1 second"),
Expand All @@ -91,7 +94,7 @@ def test_humanize_time(ms: int, out: str):
assert humanize_time(ms) == out


@pytest.mark.parametrize(
@mark.parametrize(
"d1,d2,result",
[
({"a": {"b": [1, 2]}, "c": 1}, {"c": 1}, True),
Expand Down Expand Up @@ -176,7 +179,7 @@ def balance_is_valid():
assert actual_balance == expected_balance, "Incorrect account balance"


@pytest.mark.parametrize(
@mark.parametrize(
"url,res",
[
("http://localhost:3001", "localhost"),
Expand All @@ -196,3 +199,41 @@ def balance_is_valid():
)
def test_tld_for_same_site(url: str, res: str):
assert get_top_level_domain_for_same_site_resolution(url) == res


@mark.parametrize(
["internet_disabled", "env_val", "expectation"],
[
param(True, "False", raises(RuntimeError), id="Internet disabled, flag unset"),
param(True, "True", outputs("google.com"), id="Internet disabled, flag set"),
param(False, "False", outputs("google.com"), id="Internet enabled, flag unset"),
param(False, "True", outputs("google.com"), id="Internet enabled, flag set"),
],
)
def test_tldextract_http_toggle(
internet_disabled: bool,
env_val: str,
expectation: Any,
# pyfakefs fixture, mocks the filesystem
# Mocking `tldextract`'s cache path does not work in repeated tests
fs: Any,
):
import socket

# Disable sockets, will raise errors on HTTP calls
socket_patch = patch.object(socket, "socket", side_effect=RuntimeError)
environ_patch = patch.dict(
os.environ,
{"SUPERTOKENS_TLDEXTRACT_DISABLE_HTTP": env_val},
)

stack = ExitStack()
stack.enter_context(environ_patch)
if internet_disabled:
stack.enter_context(socket_patch)

# if `expectation` is raises, checks for raise
# if `outputs`, value used in `assert` statement
with stack, expectation as expected_output:
output = get_top_level_domain_for_same_site_resolution("https://google.com")
assert output == expected_output
15 changes: 15 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

# Import AsyncMock
import sys
from contextlib import contextmanager
from datetime import datetime
from functools import lru_cache
from http.cookies import SimpleCookie
Expand Down Expand Up @@ -487,3 +488,17 @@ async def create_users(
await manually_create_or_update_user(
"public", user["provider"], user["userId"], user["email"], True, None
)


@contextmanager
def outputs(val: Any):
"""
Outputs a value to assert.

Usage:
@mark.parametrize(["input", "expectation"], [(1, outputs(1)), (0, raises(Exception))])
def test_fn(input, expectation):
with expectation as expected_output:
assert 1 / input == expected_output
"""
yield val