Skip to content

Commit 0b89b2e

Browse files
committed
Add support for APIRouter prefix
1 parent 0483406 commit 0b89b2e

File tree

12 files changed

+93
-27
lines changed

12 files changed

+93
-27
lines changed

stac_fastapi/api/stac_fastapi/api/app.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def customize_openapi(self) -> Optional[Dict[str, Any]]:
333333

334334
def add_health_check(self):
335335
"""Add a health check."""
336-
mgmt_router = APIRouter()
336+
mgmt_router = APIRouter(prefix=self.app.state.router_prefix)
337337

338338
@mgmt_router.get("/_mgmt/ping")
339339
async def ping():
@@ -381,6 +381,10 @@ def __attrs_post_init__(self):
381381
self.register_core()
382382
self.app.include_router(self.router)
383383

384+
# keep link to the router prefix value
385+
router_prefix = self.router.prefix
386+
self.app.state.router_prefix = router_prefix if router_prefix else ""
387+
384388
# register extensions
385389
for ext in self.extensions:
386390
ext.register(self.app)

stac_fastapi/extensions/stac_fastapi/extensions/core/filter/filter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def register(self, app: FastAPI) -> None:
100100
Returns:
101101
None
102102
"""
103+
self.router.prefix = app.state.router_prefix
103104
self.router.add_api_route(
104105
name="Queryables",
105106
path="/queryables",

stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def register(self, app: FastAPI) -> None:
160160
Returns:
161161
None
162162
"""
163+
self.router.prefix = app.state.router_prefix
163164
self.register_create_item()
164165
self.register_update_item()
165166
self.register_delete_item()

stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def register(self, app: FastAPI) -> None:
116116
"""
117117
items_request_model = create_request_model("Items", base_model=Items)
118118

119-
router = APIRouter()
119+
router = APIRouter(prefix=app.state.router_prefix)
120120
router.add_api_route(
121121
name="Bulk Create Item",
122122
path="/collections/{collection_id}/bulk_items",

stac_fastapi/pgstac/stac_fastapi/pgstac/core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from stac_fastapi.pgstac.utils import filter_fields
2424
from stac_fastapi.types.core import AsyncBaseCoreClient
2525
from stac_fastapi.types.errors import InvalidQueryParameter, NotFoundError
26+
from stac_fastapi.types.requests import get_base_url
2627
from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection
2728

2829
NumType = Union[float, int]
@@ -35,7 +36,7 @@ class CoreCrudClient(AsyncBaseCoreClient):
3536
async def all_collections(self, **kwargs) -> Collections:
3637
"""Read all collections from the database."""
3738
request: Request = kwargs["request"]
38-
base_url = str(request.base_url)
39+
base_url = get_base_url(request)
3940
pool = request.app.state.readpool
4041

4142
async with pool.acquire() as conn:

stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from stac_pydantic.shared import MimeTypes
99
from starlette.requests import Request
1010

11+
from stac_fastapi.types.requests import get_base_url
12+
1113
# These can be inferred from the item/collection so they aren't included in the database
1214
# Instead they are dynamically generated when querying the database using the classes defined below
1315
INFERRED_LINK_RELS = ["self", "item", "parent", "collection", "root"]
@@ -45,7 +47,7 @@ class BaseLinks:
4547
@property
4648
def base_url(self):
4749
"""Get the base url."""
48-
return str(self.request.base_url)
50+
return get_base_url(self.request)
4951

5052
@property
5153
def url(self):

stac_fastapi/pgstac/tests/api/test_api.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,24 @@ async def test_api_headers(app_client):
4040
assert resp.status_code == 200
4141

4242

43-
async def test_core_router(api_client):
44-
core_routes = set(STAC_CORE_ROUTES)
43+
async def test_core_router(api_client, app):
44+
core_routes = set()
45+
for core_route in STAC_CORE_ROUTES:
46+
method, path = core_route.split(" ")
47+
core_routes.add("{} {}".format(method, app.state.router_prefix + path))
48+
4549
api_routes = set(
4650
[f"{list(route.methods)[0]} {route.path}" for route in api_client.app.routes]
4751
)
4852
assert not core_routes - api_routes
4953

5054

51-
async def test_transactions_router(api_client):
52-
transaction_routes = set(STAC_TRANSACTION_ROUTES)
55+
async def test_transactions_router(api_client, app):
56+
transaction_routes = set()
57+
for transaction_route in STAC_TRANSACTION_ROUTES:
58+
method, path = transaction_route.split(" ")
59+
transaction_routes.add("{} {}".format(method, app.state.router_prefix + path))
60+
5361
api_routes = set(
5462
[f"{list(route.methods)[0]} {route.path}" for route in api_client.app.routes]
5563
)

stac_fastapi/pgstac/tests/conftest.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import os
44
import time
55
from typing import Callable, Dict
6+
from urllib.parse import urljoin
67

78
import asyncpg
89
import pytest
10+
from fastapi import APIRouter
911
from fastapi.responses import ORJSONResponse
1012
from httpx import AsyncClient
1113
from pypgstac.db import PgstacDB
@@ -107,9 +109,26 @@ async def pgstac(pg):
107109

108110

109111
# Run all the tests that use the api_client in both db hydrate and api hydrate mode
110-
@pytest.fixture(params=[settings, pgstac_api_hydrate_settings], scope="session")
112+
@pytest.fixture(
113+
params=[
114+
(settings, ""),
115+
(settings, "/router_prefix"),
116+
(pgstac_api_hydrate_settings, ""),
117+
(pgstac_api_hydrate_settings, "/router_prefix"),
118+
],
119+
scope="session",
120+
)
111121
def api_client(request, pg):
112-
print("creating client with settings, hydrate:", request.param.use_api_hydrate)
122+
api_settings, prefix = request.param
123+
124+
api_settings.openapi_url = prefix + api_settings.openapi_url
125+
api_settings.docs_url = prefix + api_settings.docs_url
126+
127+
print(
128+
"creating client with settings, hydrate: {}, router prefix: '{}'".format(
129+
api_settings.use_api_hydrate, prefix
130+
)
131+
)
113132

114133
extensions = [
115134
TransactionExtension(client=TransactionsClient(), settings=settings),
@@ -122,12 +141,13 @@ def api_client(request, pg):
122141
]
123142
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
124143
api = StacApi(
125-
settings=request.param,
144+
settings=api_settings,
126145
extensions=extensions,
127146
client=CoreCrudClient(post_request_model=post_request_model),
128147
search_get_request_model=create_get_request_model(extensions),
129148
search_post_request_model=post_request_model,
130149
response_class=ORJSONResponse,
150+
router=APIRouter(prefix=prefix),
131151
)
132152

133153
return api
@@ -150,7 +170,12 @@ async def app(api_client):
150170
@pytest.fixture(scope="function")
151171
async def app_client(app):
152172
print("creating app_client")
153-
async with AsyncClient(app=app, base_url="http://test") as c:
173+
174+
base_url = "http://test"
175+
if app.state.router_prefix != "":
176+
base_url = urljoin(base_url, app.state.router_prefix)
177+
178+
async with AsyncClient(app=app, base_url=base_url) as c:
154179
yield c
155180

156181

stac_fastapi/pgstac/tests/resources/test_conformance.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,25 +46,25 @@ def test_landing_page_health(response):
4646

4747
@pytest.mark.parametrize("rel_type,expected_media_type,expected_path", link_tests)
4848
async def test_landing_page_links(
49-
response_json: Dict, app_client, rel_type, expected_media_type, expected_path
49+
response_json: Dict, app_client, app, rel_type, expected_media_type, expected_path
5050
):
5151
link = get_link(response_json, rel_type)
5252

5353
assert link is not None, f"Missing {rel_type} link in landing page"
5454
assert link.get("type") == expected_media_type
5555

5656
link_path = urllib.parse.urlsplit(link.get("href")).path
57-
assert link_path == expected_path
57+
assert link_path == app.state.router_prefix + expected_path
5858

59-
resp = await app_client.get(link_path)
59+
resp = await app_client.get(link_path.rsplit("/", 1)[-1])
6060
assert resp.status_code == 200
6161

6262

6363
# This endpoint currently returns a 404 for empty result sets, but testing for this response
6464
# code here seems meaningless since it would be the same as if the endpoint did not exist. Once
6565
# https://github.com/stac-utils/stac-fastapi/pull/227 has been merged we can add this to the
6666
# parameterized tests above.
67-
def test_search_link(response_json: Dict):
67+
def test_search_link(response_json: Dict, app):
6868
for search_link in [
6969
get_link(response_json, "search", "GET"),
7070
get_link(response_json, "search", "POST"),
@@ -73,4 +73,4 @@ def test_search_link(response_json: Dict):
7373
assert search_link.get("type") == "application/geo+json"
7474

7575
search_path = urllib.parse.urlsplit(search_link.get("href")).path
76-
assert search_path == "/search"
76+
assert search_path == app.state.router_prefix + "/search"

stac_fastapi/pgstac/tests/resources/test_item.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1164,7 +1164,7 @@ async def test_get_missing_item(app_client, load_test_data):
11641164
assert resp.status_code == 404
11651165

11661166

1167-
async def test_relative_link_construction():
1167+
async def test_relative_link_construction(app):
11681168
req = Request(
11691169
scope={
11701170
"type": "http",
@@ -1175,10 +1175,13 @@ async def test_relative_link_construction():
11751175
"raw_path": b"/tab/abc",
11761176
"query_string": b"",
11771177
"headers": {},
1178+
"app": app,
11781179
}
11791180
)
11801181
links = CollectionLinks(collection_id="naip", request=req)
1181-
assert links.link_items()["href"] == "http://test/stac/collections/naip/items"
1182+
assert links.link_items()["href"] == (
1183+
"http://test/stac{}/collections/naip/items".format(app.state.router_prefix)
1184+
)
11821185

11831186

11841187
async def test_search_bbox_errors(app_client):

stac_fastapi/types/stac_fastapi/types/core.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from stac_fastapi.types import stac as stac_types
1515
from stac_fastapi.types.conformance import BASE_CONFORMANCE_CLASSES
1616
from stac_fastapi.types.extension import ApiExtension
17+
from stac_fastapi.types.requests import get_base_url
1718
from stac_fastapi.types.search import BaseSearchPostRequest
1819
from stac_fastapi.types.stac import Conformance
1920

@@ -349,12 +350,10 @@ def landing_page(self, **kwargs) -> stac_types.LandingPage:
349350
API landing page, serving as an entry point to the API.
350351
"""
351352
request: Request = kwargs["request"]
352-
base_url = str(request.base_url)
353+
base_url = get_base_url(request)
353354
extension_schemas = [
354355
schema.schema_href for schema in self.extensions if schema.schema_href
355356
]
356-
request: Request = kwargs["request"]
357-
base_url = str(request.base_url)
358357
landing_page = self._landing_page(
359358
base_url=base_url,
360359
conformance_classes=self.conformance_classes(),
@@ -379,7 +378,9 @@ def landing_page(self, **kwargs) -> stac_types.LandingPage:
379378
"rel": "service-desc",
380379
"type": "application/vnd.oai.openapi+json;version=3.0",
381380
"title": "OpenAPI service description",
382-
"href": urljoin(base_url, request.app.openapi_url.lstrip("/")),
381+
"href": urljoin(
382+
str(request.base_url), request.app.openapi_url.lstrip("/")
383+
),
383384
}
384385
)
385386

@@ -389,7 +390,9 @@ def landing_page(self, **kwargs) -> stac_types.LandingPage:
389390
"rel": "service-doc",
390391
"type": "text/html",
391392
"title": "OpenAPI service documentation",
392-
"href": urljoin(base_url, request.app.docs_url.lstrip("/")),
393+
"href": urljoin(
394+
str(request.base_url), request.app.docs_url.lstrip("/")
395+
),
393396
}
394397
)
395398

@@ -540,7 +543,7 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage:
540543
API landing page, serving as an entry point to the API.
541544
"""
542545
request: Request = kwargs["request"]
543-
base_url = str(request.base_url)
546+
base_url = get_base_url(request)
544547
extension_schemas = [
545548
schema.schema_href for schema in self.extensions if schema.schema_href
546549
]
@@ -566,7 +569,9 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage:
566569
"rel": "service-desc",
567570
"type": "application/vnd.oai.openapi+json;version=3.0",
568571
"title": "OpenAPI service description",
569-
"href": urljoin(base_url, request.app.openapi_url.lstrip("/")),
572+
"href": urljoin(
573+
str(request.base_url), request.app.openapi_url.lstrip("/")
574+
),
570575
}
571576
)
572577

@@ -576,7 +581,9 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage:
576581
"rel": "service-doc",
577582
"type": "text/html",
578583
"title": "OpenAPI service documentation",
579-
"href": urljoin(base_url, request.app.docs_url.lstrip("/")),
584+
"href": urljoin(
585+
str(request.base_url), request.app.docs_url.lstrip("/")
586+
),
580587
}
581588
)
582589

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""requests helpers."""
2+
3+
from starlette.requests import Request
4+
5+
6+
def get_base_url(request: Request) -> str:
7+
"""Get base URL with respect of APIRouter prefix."""
8+
app = request.app
9+
if not app.state.router_prefix:
10+
return str(request.base_url)
11+
else:
12+
return "{}{}/".format(
13+
str(request.base_url), app.state.router_prefix.lstrip("/")
14+
)

0 commit comments

Comments
 (0)