Skip to content

Commit 7f12a57

Browse files
committed
Add support for APIRouter prefix
1 parent 1faabd3 commit 7f12a57

File tree

13 files changed

+94
-25
lines changed

13 files changed

+94
-25
lines changed

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
* Ability to POST an ItemCollection to the collections/{collectionId}/items route. ([#367](https://github.com/stac-utils/stac-fastapi/pull/367))
99
* Add STAC API - Collections conformance class. ([383](https://github.com/stac-utils/stac-fastapi/pull/383))
1010
* Bulk item inserts for pgstac implementation. ([411](https://github.com/stac-utils/stac-fastapi/pull/411))
11+
* Add APIRouter prefix support for pgstac implementation. ([429](https://github.com/stac-utils/stac-fastapi/pull/429))
1112

1213
### Changed
1314

stac_fastapi/api/stac_fastapi/api/app.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def customize_openapi(self) -> Optional[Dict[str, Any]]:
333333

334334
def add_health_check(self):
335335
"""Add a health check."""
336-
mgmt_router = APIRouter()
336+
mgmt_router = APIRouter(prefix=self.app.state.router_prefix)
337337

338338
@mgmt_router.get("/_mgmt/ping")
339339
async def ping():
@@ -381,6 +381,10 @@ def __attrs_post_init__(self):
381381
self.register_core()
382382
self.app.include_router(self.router)
383383

384+
# keep link to the router prefix value
385+
router_prefix = self.router.prefix
386+
self.app.state.router_prefix = router_prefix if router_prefix else ""
387+
384388
# register extensions
385389
for ext in self.extensions:
386390
ext.register(self.app)

stac_fastapi/extensions/stac_fastapi/extensions/core/filter/filter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def register(self, app: FastAPI) -> None:
105105
Returns:
106106
None
107107
"""
108+
self.router.prefix = app.state.router_prefix
108109
self.router.add_api_route(
109110
name="Queryables",
110111
path="/queryables",

stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def register(self, app: FastAPI) -> None:
160160
Returns:
161161
None
162162
"""
163+
self.router.prefix = app.state.router_prefix
163164
self.register_create_item()
164165
self.register_update_item()
165166
self.register_delete_item()

stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def register(self, app: FastAPI) -> None:
116116
"""
117117
items_request_model = create_request_model("Items", base_model=Items)
118118

119-
router = APIRouter()
119+
router = APIRouter(prefix=app.state.router_prefix)
120120
router.add_api_route(
121121
name="Bulk Create Item",
122122
path="/collections/{collection_id}/bulk_items",

stac_fastapi/pgstac/stac_fastapi/pgstac/core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from stac_fastapi.pgstac.utils import filter_fields
2424
from stac_fastapi.types.core import AsyncBaseCoreClient
2525
from stac_fastapi.types.errors import InvalidQueryParameter, NotFoundError
26+
from stac_fastapi.types.requests import get_base_url
2627
from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection
2728

2829
NumType = Union[float, int]
@@ -35,7 +36,7 @@ class CoreCrudClient(AsyncBaseCoreClient):
3536
async def all_collections(self, **kwargs) -> Collections:
3637
"""Read all collections from the database."""
3738
request: Request = kwargs["request"]
38-
base_url = str(request.base_url)
39+
base_url = get_base_url(request)
3940
pool = request.app.state.readpool
4041

4142
async with pool.acquire() as conn:

stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from stac_pydantic.shared import MimeTypes
99
from starlette.requests import Request
1010

11+
from stac_fastapi.types.requests import get_base_url
12+
1113
# These can be inferred from the item/collection so they aren't included in the database
1214
# Instead they are dynamically generated when querying the database using the classes defined below
1315
INFERRED_LINK_RELS = ["self", "item", "parent", "collection", "root"]
@@ -45,7 +47,7 @@ class BaseLinks:
4547
@property
4648
def base_url(self):
4749
"""Get the base url."""
48-
return str(self.request.base_url)
50+
return get_base_url(self.request)
4951

5052
@property
5153
def url(self):

stac_fastapi/pgstac/tests/api/test_api.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,24 @@ async def test_api_headers(app_client):
4949
assert resp.status_code == 200
5050

5151

52-
async def test_core_router(api_client):
53-
core_routes = set(STAC_CORE_ROUTES)
52+
async def test_core_router(api_client, app):
53+
core_routes = set()
54+
for core_route in STAC_CORE_ROUTES:
55+
method, path = core_route.split(" ")
56+
core_routes.add("{} {}".format(method, app.state.router_prefix + path))
57+
5458
api_routes = set(
5559
[f"{list(route.methods)[0]} {route.path}" for route in api_client.app.routes]
5660
)
5761
assert not core_routes - api_routes
5862

5963

60-
async def test_transactions_router(api_client):
61-
transaction_routes = set(STAC_TRANSACTION_ROUTES)
64+
async def test_transactions_router(api_client, app):
65+
transaction_routes = set()
66+
for transaction_route in STAC_TRANSACTION_ROUTES:
67+
method, path = transaction_route.split(" ")
68+
transaction_routes.add("{} {}".format(method, app.state.router_prefix + path))
69+
6270
api_routes = set(
6371
[f"{list(route.methods)[0]} {route.path}" for route in api_client.app.routes]
6472
)

stac_fastapi/pgstac/tests/conftest.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import os
44
import time
55
from typing import Callable, Dict
6+
from urllib.parse import urljoin
67

78
import asyncpg
89
import pytest
10+
from fastapi import APIRouter
911
from fastapi.responses import ORJSONResponse
1012
from httpx import AsyncClient
1113
from pypgstac.db import PgstacDB
@@ -107,9 +109,26 @@ async def pgstac(pg):
107109

108110

109111
# Run all the tests that use the api_client in both db hydrate and api hydrate mode
110-
@pytest.fixture(params=[settings, pgstac_api_hydrate_settings], scope="session")
112+
@pytest.fixture(
113+
params=[
114+
(settings, ""),
115+
(settings, "/router_prefix"),
116+
(pgstac_api_hydrate_settings, ""),
117+
(pgstac_api_hydrate_settings, "/router_prefix"),
118+
],
119+
scope="session",
120+
)
111121
def api_client(request, pg):
112-
print("creating client with settings, hydrate:", request.param.use_api_hydrate)
122+
api_settings, prefix = request.param
123+
124+
api_settings.openapi_url = prefix + api_settings.openapi_url
125+
api_settings.docs_url = prefix + api_settings.docs_url
126+
127+
print(
128+
"creating client with settings, hydrate: {}, router prefix: '{}'".format(
129+
api_settings.use_api_hydrate, prefix
130+
)
131+
)
113132

114133
extensions = [
115134
TransactionExtension(client=TransactionsClient(), settings=settings),
@@ -122,12 +141,13 @@ def api_client(request, pg):
122141
]
123142
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
124143
api = StacApi(
125-
settings=request.param,
144+
settings=api_settings,
126145
extensions=extensions,
127146
client=CoreCrudClient(post_request_model=post_request_model),
128147
search_get_request_model=create_get_request_model(extensions),
129148
search_post_request_model=post_request_model,
130149
response_class=ORJSONResponse,
150+
router=APIRouter(prefix=prefix),
131151
)
132152

133153
return api
@@ -150,7 +170,12 @@ async def app(api_client):
150170
@pytest.fixture(scope="function")
151171
async def app_client(app):
152172
print("creating app_client")
153-
async with AsyncClient(app=app, base_url="http://test") as c:
173+
174+
base_url = "http://test"
175+
if app.state.router_prefix != "":
176+
base_url = urljoin(base_url, app.state.router_prefix)
177+
178+
async with AsyncClient(app=app, base_url=base_url) as c:
154179
yield c
155180

156181

stac_fastapi/pgstac/tests/resources/test_conformance.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,25 +46,25 @@ def test_landing_page_health(response):
4646

4747
@pytest.mark.parametrize("rel_type,expected_media_type,expected_path", link_tests)
4848
async def test_landing_page_links(
49-
response_json: Dict, app_client, rel_type, expected_media_type, expected_path
49+
response_json: Dict, app_client, app, rel_type, expected_media_type, expected_path
5050
):
5151
link = get_link(response_json, rel_type)
5252

5353
assert link is not None, f"Missing {rel_type} link in landing page"
5454
assert link.get("type") == expected_media_type
5555

5656
link_path = urllib.parse.urlsplit(link.get("href")).path
57-
assert link_path == expected_path
57+
assert link_path == app.state.router_prefix + expected_path
5858

59-
resp = await app_client.get(link_path)
59+
resp = await app_client.get(link_path.rsplit("/", 1)[-1])
6060
assert resp.status_code == 200
6161

6262

6363
# This endpoint currently returns a 404 for empty result sets, but testing for this response
6464
# code here seems meaningless since it would be the same as if the endpoint did not exist. Once
6565
# https://github.com/stac-utils/stac-fastapi/pull/227 has been merged we can add this to the
6666
# parameterized tests above.
67-
def test_search_link(response_json: Dict):
67+
def test_search_link(response_json: Dict, app):
6868
for search_link in [
6969
get_link(response_json, "search", "GET"),
7070
get_link(response_json, "search", "POST"),
@@ -73,4 +73,4 @@ def test_search_link(response_json: Dict):
7373
assert search_link.get("type") == "application/geo+json"
7474

7575
search_path = urllib.parse.urlsplit(search_link.get("href")).path
76-
assert search_path == "/search"
76+
assert search_path == app.state.router_prefix + "/search"

stac_fastapi/pgstac/tests/resources/test_item.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1164,7 +1164,7 @@ async def test_get_missing_item(app_client, load_test_data):
11641164
assert resp.status_code == 404
11651165

11661166

1167-
async def test_relative_link_construction():
1167+
async def test_relative_link_construction(app):
11681168
req = Request(
11691169
scope={
11701170
"type": "http",
@@ -1175,10 +1175,13 @@ async def test_relative_link_construction():
11751175
"raw_path": b"/tab/abc",
11761176
"query_string": b"",
11771177
"headers": {},
1178+
"app": app,
11781179
}
11791180
)
11801181
links = CollectionLinks(collection_id="naip", request=req)
1181-
assert links.link_items()["href"] == "http://test/stac/collections/naip/items"
1182+
assert links.link_items()["href"] == (
1183+
"http://test/stac{}/collections/naip/items".format(app.state.router_prefix)
1184+
)
11821185

11831186

11841187
async def test_search_bbox_errors(app_client):

stac_fastapi/types/stac_fastapi/types/core.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from stac_fastapi.types import stac as stac_types
1515
from stac_fastapi.types.conformance import BASE_CONFORMANCE_CLASSES
1616
from stac_fastapi.types.extension import ApiExtension
17+
from stac_fastapi.types.requests import get_base_url
1718
from stac_fastapi.types.search import BaseSearchPostRequest
1819
from stac_fastapi.types.stac import Conformance
1920

@@ -349,7 +350,7 @@ def landing_page(self, **kwargs) -> stac_types.LandingPage:
349350
API landing page, serving as an entry point to the API.
350351
"""
351352
request: Request = kwargs["request"]
352-
base_url = str(request.base_url)
353+
base_url = get_base_url(request)
353354
extension_schemas = [
354355
schema.schema_href for schema in self.extensions if schema.schema_href
355356
]
@@ -377,7 +378,9 @@ def landing_page(self, **kwargs) -> stac_types.LandingPage:
377378
"rel": "service-desc",
378379
"type": "application/vnd.oai.openapi+json;version=3.0",
379380
"title": "OpenAPI service description",
380-
"href": urljoin(base_url, request.app.openapi_url.lstrip("/")),
381+
"href": urljoin(
382+
str(request.base_url), request.app.openapi_url.lstrip("/")
383+
),
381384
}
382385
)
383386

@@ -387,7 +390,9 @@ def landing_page(self, **kwargs) -> stac_types.LandingPage:
387390
"rel": "service-doc",
388391
"type": "text/html",
389392
"title": "OpenAPI service documentation",
390-
"href": urljoin(base_url, request.app.docs_url.lstrip("/")),
393+
"href": urljoin(
394+
str(request.base_url), request.app.docs_url.lstrip("/")
395+
),
391396
}
392397
)
393398

@@ -538,7 +543,7 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage:
538543
API landing page, serving as an entry point to the API.
539544
"""
540545
request: Request = kwargs["request"]
541-
base_url = str(request.base_url)
546+
base_url = get_base_url(request)
542547
extension_schemas = [
543548
schema.schema_href for schema in self.extensions if schema.schema_href
544549
]
@@ -564,7 +569,9 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage:
564569
"rel": "service-desc",
565570
"type": "application/vnd.oai.openapi+json;version=3.0",
566571
"title": "OpenAPI service description",
567-
"href": urljoin(base_url, request.app.openapi_url.lstrip("/")),
572+
"href": urljoin(
573+
str(request.base_url), request.app.openapi_url.lstrip("/")
574+
),
568575
}
569576
)
570577

@@ -574,7 +581,9 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage:
574581
"rel": "service-doc",
575582
"type": "text/html",
576583
"title": "OpenAPI service documentation",
577-
"href": urljoin(base_url, request.app.docs_url.lstrip("/")),
584+
"href": urljoin(
585+
str(request.base_url), request.app.docs_url.lstrip("/")
586+
),
578587
}
579588
)
580589

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""requests helpers."""
2+
3+
from starlette.requests import Request
4+
5+
6+
def get_base_url(request: Request) -> str:
7+
"""Get base URL with respect of APIRouter prefix."""
8+
app = request.app
9+
if not app.state.router_prefix:
10+
return str(request.base_url)
11+
else:
12+
return "{}{}/".format(
13+
str(request.base_url), app.state.router_prefix.lstrip("/")
14+
)

0 commit comments

Comments
 (0)