Skip to content

Commit f27dc9a

Browse files
Add support for APIRouter prefix (#429)
Co-authored-by: Jeff Albrecht <[email protected]>
1 parent b7580fe commit f27dc9a

File tree

13 files changed

+94
-25
lines changed

13 files changed

+94
-25
lines changed

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
* Ability to POST an ItemCollection to the collections/{collectionId}/items route. ([#367](https://github.com/stac-utils/stac-fastapi/pull/367))
99
* Add STAC API - Collections conformance class. ([383](https://github.com/stac-utils/stac-fastapi/pull/383))
1010
* Bulk item inserts for pgstac implementation. ([411](https://github.com/stac-utils/stac-fastapi/pull/411))
11+
* Add APIRouter prefix support for pgstac implementation. ([429](https://github.com/stac-utils/stac-fastapi/pull/429))
1112
* Respect `Forwarded` or `X-Forwarded-*` request headers when building links to better accommodate load balancers and proxies.
1213

1314
### Changed

stac_fastapi/api/stac_fastapi/api/app.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def customize_openapi(self) -> Optional[Dict[str, Any]]:
336336

337337
def add_health_check(self):
338338
"""Add a health check."""
339-
mgmt_router = APIRouter()
339+
mgmt_router = APIRouter(prefix=self.app.state.router_prefix)
340340

341341
@mgmt_router.get("/_mgmt/ping")
342342
async def ping():
@@ -384,6 +384,10 @@ def __attrs_post_init__(self):
384384
self.register_core()
385385
self.app.include_router(self.router)
386386

387+
# keep link to the router prefix value
388+
router_prefix = self.router.prefix
389+
self.app.state.router_prefix = router_prefix if router_prefix else ""
390+
387391
# register extensions
388392
for ext in self.extensions:
389393
ext.register(self.app)

stac_fastapi/extensions/stac_fastapi/extensions/core/filter/filter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def register(self, app: FastAPI) -> None:
107107
Returns:
108108
None
109109
"""
110+
self.router.prefix = app.state.router_prefix
110111
self.router.add_api_route(
111112
name="Queryables",
112113
path="/queryables",

stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def register(self, app: FastAPI) -> None:
160160
Returns:
161161
None
162162
"""
163+
self.router.prefix = app.state.router_prefix
163164
self.register_create_item()
164165
self.register_update_item()
165166
self.register_delete_item()

stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def register(self, app: FastAPI) -> None:
116116
"""
117117
items_request_model = create_request_model("Items", base_model=Items)
118118

119-
router = APIRouter()
119+
router = APIRouter(prefix=app.state.router_prefix)
120120
router.add_api_route(
121121
name="Bulk Create Item",
122122
path="/collections/{collection_id}/bulk_items",

stac_fastapi/pgstac/stac_fastapi/pgstac/core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from stac_fastapi.pgstac.utils import filter_fields
2424
from stac_fastapi.types.core import AsyncBaseCoreClient
2525
from stac_fastapi.types.errors import InvalidQueryParameter, NotFoundError
26+
from stac_fastapi.types.requests import get_base_url
2627
from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection
2728

2829
NumType = Union[float, int]
@@ -35,7 +36,7 @@ class CoreCrudClient(AsyncBaseCoreClient):
3536
async def all_collections(self, **kwargs) -> Collections:
3637
"""Read all collections from the database."""
3738
request: Request = kwargs["request"]
38-
base_url = str(request.base_url)
39+
base_url = get_base_url(request)
3940
pool = request.app.state.readpool
4041

4142
async with pool.acquire() as conn:

stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from stac_pydantic.shared import MimeTypes
99
from starlette.requests import Request
1010

11+
from stac_fastapi.types.requests import get_base_url
12+
1113
# These can be inferred from the item/collection so they aren't included in the database
1214
# Instead they are dynamically generated when querying the database using the classes defined below
1315
INFERRED_LINK_RELS = ["self", "item", "parent", "collection", "root"]
@@ -45,7 +47,7 @@ class BaseLinks:
4547
@property
4648
def base_url(self):
4749
"""Get the base url."""
48-
return str(self.request.base_url)
50+
return get_base_url(self.request)
4951

5052
@property
5153
def url(self):

stac_fastapi/pgstac/tests/api/test_api.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,24 @@ async def test_api_headers(app_client):
5151
assert resp.status_code == 200
5252

5353

54-
async def test_core_router(api_client):
55-
core_routes = set(STAC_CORE_ROUTES)
54+
async def test_core_router(api_client, app):
55+
core_routes = set()
56+
for core_route in STAC_CORE_ROUTES:
57+
method, path = core_route.split(" ")
58+
core_routes.add("{} {}".format(method, app.state.router_prefix + path))
59+
5660
api_routes = set(
5761
[f"{list(route.methods)[0]} {route.path}" for route in api_client.app.routes]
5862
)
5963
assert not core_routes - api_routes
6064

6165

62-
async def test_transactions_router(api_client):
63-
transaction_routes = set(STAC_TRANSACTION_ROUTES)
66+
async def test_transactions_router(api_client, app):
67+
transaction_routes = set()
68+
for transaction_route in STAC_TRANSACTION_ROUTES:
69+
method, path = transaction_route.split(" ")
70+
transaction_routes.add("{} {}".format(method, app.state.router_prefix + path))
71+
6472
api_routes = set(
6573
[f"{list(route.methods)[0]} {route.path}" for route in api_client.app.routes]
6674
)

stac_fastapi/pgstac/tests/conftest.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import os
44
import time
55
from typing import Callable, Dict
6+
from urllib.parse import urljoin
67

78
import asyncpg
89
import pytest
10+
from fastapi import APIRouter
911
from fastapi.responses import ORJSONResponse
1012
from httpx import AsyncClient
1113
from pypgstac.db import PgstacDB
@@ -107,9 +109,26 @@ async def pgstac(pg):
107109

108110

109111
# Run all the tests that use the api_client in both db hydrate and api hydrate mode
110-
@pytest.fixture(params=[settings, pgstac_api_hydrate_settings], scope="session")
112+
@pytest.fixture(
113+
params=[
114+
(settings, ""),
115+
(settings, "/router_prefix"),
116+
(pgstac_api_hydrate_settings, ""),
117+
(pgstac_api_hydrate_settings, "/router_prefix"),
118+
],
119+
scope="session",
120+
)
111121
def api_client(request, pg):
112-
print("creating client with settings, hydrate:", request.param.use_api_hydrate)
122+
api_settings, prefix = request.param
123+
124+
api_settings.openapi_url = prefix + api_settings.openapi_url
125+
api_settings.docs_url = prefix + api_settings.docs_url
126+
127+
print(
128+
"creating client with settings, hydrate: {}, router prefix: '{}'".format(
129+
api_settings.use_api_hydrate, prefix
130+
)
131+
)
113132

114133
extensions = [
115134
TransactionExtension(client=TransactionsClient(), settings=settings),
@@ -122,12 +141,13 @@ def api_client(request, pg):
122141
]
123142
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
124143
api = StacApi(
125-
settings=request.param,
144+
settings=api_settings,
126145
extensions=extensions,
127146
client=CoreCrudClient(post_request_model=post_request_model),
128147
search_get_request_model=create_get_request_model(extensions),
129148
search_post_request_model=post_request_model,
130149
response_class=ORJSONResponse,
150+
router=APIRouter(prefix=prefix),
131151
)
132152

133153
return api
@@ -150,7 +170,12 @@ async def app(api_client):
150170
@pytest.fixture(scope="function")
151171
async def app_client(app):
152172
print("creating app_client")
153-
async with AsyncClient(app=app, base_url="http://test") as c:
173+
174+
base_url = "http://test"
175+
if app.state.router_prefix != "":
176+
base_url = urljoin(base_url, app.state.router_prefix)
177+
178+
async with AsyncClient(app=app, base_url=base_url) as c:
154179
yield c
155180

156181

stac_fastapi/pgstac/tests/resources/test_conformance.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,25 +46,25 @@ def test_landing_page_health(response):
4646

4747
@pytest.mark.parametrize("rel_type,expected_media_type,expected_path", link_tests)
4848
async def test_landing_page_links(
49-
response_json: Dict, app_client, rel_type, expected_media_type, expected_path
49+
response_json: Dict, app_client, app, rel_type, expected_media_type, expected_path
5050
):
5151
link = get_link(response_json, rel_type)
5252

5353
assert link is not None, f"Missing {rel_type} link in landing page"
5454
assert link.get("type") == expected_media_type
5555

5656
link_path = urllib.parse.urlsplit(link.get("href")).path
57-
assert link_path == expected_path
57+
assert link_path == app.state.router_prefix + expected_path
5858

59-
resp = await app_client.get(link_path)
59+
resp = await app_client.get(link_path.rsplit("/", 1)[-1])
6060
assert resp.status_code == 200
6161

6262

6363
# This endpoint currently returns a 404 for empty result sets, but testing for this response
6464
# code here seems meaningless since it would be the same as if the endpoint did not exist. Once
6565
# https://github.com/stac-utils/stac-fastapi/pull/227 has been merged we can add this to the
6666
# parameterized tests above.
67-
def test_search_link(response_json: Dict):
67+
def test_search_link(response_json: Dict, app):
6868
for search_link in [
6969
get_link(response_json, "search", "GET"),
7070
get_link(response_json, "search", "POST"),
@@ -73,4 +73,4 @@ def test_search_link(response_json: Dict):
7373
assert search_link.get("type") == "application/geo+json"
7474

7575
search_path = urllib.parse.urlsplit(search_link.get("href")).path
76-
assert search_path == "/search"
76+
assert search_path == app.state.router_prefix + "/search"

stac_fastapi/pgstac/tests/resources/test_item.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1166,7 +1166,7 @@ async def test_get_missing_item(app_client, load_test_data):
11661166
assert resp.status_code == 404
11671167

11681168

1169-
async def test_relative_link_construction():
1169+
async def test_relative_link_construction(app):
11701170
req = Request(
11711171
scope={
11721172
"type": "http",
@@ -1177,11 +1177,14 @@ async def test_relative_link_construction():
11771177
"raw_path": b"/tab/abc",
11781178
"query_string": b"",
11791179
"headers": {},
1180+
"app": app,
11801181
"server": ("test", HTTP_PORT),
11811182
}
11821183
)
11831184
links = CollectionLinks(collection_id="naip", request=req)
1184-
assert links.link_items()["href"] == "http://test/stac/collections/naip/items"
1185+
assert links.link_items()["href"] == (
1186+
"http://test/stac{}/collections/naip/items".format(app.state.router_prefix)
1187+
)
11851188

11861189

11871190
async def test_search_bbox_errors(app_client):

stac_fastapi/types/stac_fastapi/types/core.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from stac_fastapi.types import stac as stac_types
1515
from stac_fastapi.types.conformance import BASE_CONFORMANCE_CLASSES
1616
from stac_fastapi.types.extension import ApiExtension
17+
from stac_fastapi.types.requests import get_base_url
1718
from stac_fastapi.types.search import BaseSearchPostRequest
1819
from stac_fastapi.types.stac import Conformance
1920

@@ -349,7 +350,7 @@ def landing_page(self, **kwargs) -> stac_types.LandingPage:
349350
API landing page, serving as an entry point to the API.
350351
"""
351352
request: Request = kwargs["request"]
352-
base_url = str(request.base_url)
353+
base_url = get_base_url(request)
353354
extension_schemas = [
354355
schema.schema_href for schema in self.extensions if schema.schema_href
355356
]
@@ -377,7 +378,9 @@ def landing_page(self, **kwargs) -> stac_types.LandingPage:
377378
"rel": "service-desc",
378379
"type": "application/vnd.oai.openapi+json;version=3.0",
379380
"title": "OpenAPI service description",
380-
"href": urljoin(base_url, request.app.openapi_url.lstrip("/")),
381+
"href": urljoin(
382+
str(request.base_url), request.app.openapi_url.lstrip("/")
383+
),
381384
}
382385
)
383386

@@ -387,7 +390,9 @@ def landing_page(self, **kwargs) -> stac_types.LandingPage:
387390
"rel": "service-doc",
388391
"type": "text/html",
389392
"title": "OpenAPI service documentation",
390-
"href": urljoin(base_url, request.app.docs_url.lstrip("/")),
393+
"href": urljoin(
394+
str(request.base_url), request.app.docs_url.lstrip("/")
395+
),
391396
}
392397
)
393398

@@ -538,7 +543,7 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage:
538543
API landing page, serving as an entry point to the API.
539544
"""
540545
request: Request = kwargs["request"]
541-
base_url = str(request.base_url)
546+
base_url = get_base_url(request)
542547
extension_schemas = [
543548
schema.schema_href for schema in self.extensions if schema.schema_href
544549
]
@@ -564,7 +569,9 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage:
564569
"rel": "service-desc",
565570
"type": "application/vnd.oai.openapi+json;version=3.0",
566571
"title": "OpenAPI service description",
567-
"href": urljoin(base_url, request.app.openapi_url.lstrip("/")),
572+
"href": urljoin(
573+
str(request.base_url), request.app.openapi_url.lstrip("/")
574+
),
568575
}
569576
)
570577

@@ -574,7 +581,9 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage:
574581
"rel": "service-doc",
575582
"type": "text/html",
576583
"title": "OpenAPI service documentation",
577-
"href": urljoin(base_url, request.app.docs_url.lstrip("/")),
584+
"href": urljoin(
585+
str(request.base_url), request.app.docs_url.lstrip("/")
586+
),
578587
}
579588
)
580589

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""requests helpers."""
2+
3+
from starlette.requests import Request
4+
5+
6+
def get_base_url(request: Request) -> str:
7+
"""Get base URL with respect of APIRouter prefix."""
8+
app = request.app
9+
if not app.state.router_prefix:
10+
return str(request.base_url)
11+
else:
12+
return "{}{}/".format(
13+
str(request.base_url), app.state.router_prefix.lstrip("/")
14+
)

0 commit comments

Comments
 (0)