Skip to content

Commit 6cef4d4

Browse files
authored
Allow custom models for search GET and items endpoints (#271)
* Allow ItemCollectionUri and SearchGetRequest models to be overridden This allows setting a different default limit * Enable usage of custom search model in pgstac * Update CHANGES.md
1 parent 97e0439 commit 6cef4d4

File tree

3 files changed

+18
-8
lines changed

3 files changed

+18
-8
lines changed

CHANGES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44

55
### Added
66

7+
* Add ability to override ItemCollectionUri and SearchGetRequest models ([#271](https://github.com/stac-utils/stac-fastapi/pull/271))
78
* Added `collections` attribute to list of default fields to include, so that we satisfy the STAC API spec, which requires a `collections` attribute to be output when an item is part of a collection
9+
810
### Removed
911

1012
### Changed

stac_fastapi/api/stac_fastapi/api/app.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ class StacApi:
7373
stac_version: str = attr.ib(default=STAC_VERSION)
7474
description: str = attr.ib(default="stac-fastapi")
7575
search_request_model: Type[Search] = attr.ib(default=STACSearch)
76+
search_get_request: Type[SearchGetRequest] = attr.ib(default=SearchGetRequest)
77+
item_collection_uri: Type[ItemCollectionUri] = attr.ib(default=ItemCollectionUri)
7678
response_class: Type[Response] = attr.ib(default=JSONResponse)
7779
middlewares: List = attr.ib(default=attr.Factory(lambda: [BrotliMiddleware]))
7880

@@ -199,7 +201,9 @@ def register_get_search(self):
199201
response_model_exclude_unset=True,
200202
response_model_exclude_none=True,
201203
methods=["GET"],
202-
endpoint=self._create_endpoint(self.client.get_search, SearchGetRequest),
204+
endpoint=self._create_endpoint(
205+
self.client.get_search, self.search_get_request
206+
),
203207
)
204208

205209
def register_get_collections(self):
@@ -255,7 +259,7 @@ def register_get_item_collection(self):
255259
response_model_exclude_none=True,
256260
methods=["GET"],
257261
endpoint=self._create_endpoint(
258-
self.client.item_collection, ItemCollectionUri
262+
self.client.item_collection, self.item_collection_uri
259263
),
260264
)
261265

stac_fastapi/pgstac/stac_fastapi/pgstac/core.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Item crud client."""
22
import re
33
from datetime import datetime
4-
from typing import Any, Dict, List, Optional, Union
4+
from typing import Any, Dict, List, Optional, Type, Union
55
from urllib.parse import urljoin
66

77
import attr
@@ -27,6 +27,8 @@
2727
class CoreCrudClient(AsyncBaseCoreClient):
2828
"""Client for core endpoints defined by stac."""
2929

30+
search_request_model: Type[PgstacSearch] = attr.ib(init=False, default=PgstacSearch)
31+
3032
async def all_collections(self, **kwargs) -> Collections:
3133
"""Read all collections from the database."""
3234
request: Request = kwargs["request"]
@@ -168,7 +170,7 @@ async def _search_base(
168170
return collection
169171

170172
async def item_collection(
171-
self, id: str, limit: int = 10, token: str = None, **kwargs
173+
self, id: str, limit: Optional[int] = None, token: str = None, **kwargs
172174
) -> ItemCollection:
173175
"""Get all items from a specific collection.
174176
@@ -185,7 +187,7 @@ async def item_collection(
185187
# If collection does not exist, NotFoundError wil be raised
186188
await self.get_collection(id, **kwargs)
187189

188-
req = PgstacSearch(collections=[id], limit=limit, token=token)
190+
req = self.search_request_model(collections=[id], limit=limit, token=token)
189191
item_collection = await self._search_base(req, **kwargs)
190192
links = await CollectionLinks(
191193
collection_id=id, request=kwargs["request"]
@@ -207,7 +209,9 @@ async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item:
207209
# If collection does not exist, NotFoundError wil be raised
208210
await self.get_collection(collection_id, **kwargs)
209211

210-
req = PgstacSearch(ids=[item_id], collections=[collection_id], limit=1)
212+
req = self.search_request_model(
213+
ids=[item_id], collections=[collection_id], limit=1
214+
)
211215
item_collection = await self._search_base(req, **kwargs)
212216
if not item_collection["features"]:
213217
raise NotFoundError(
@@ -238,7 +242,7 @@ async def get_search(
238242
ids: Optional[List[str]] = None,
239243
bbox: Optional[List[NumType]] = None,
240244
datetime: Optional[Union[str, datetime]] = None,
241-
limit: Optional[int] = 10,
245+
limit: Optional[int] = None,
242246
query: Optional[str] = None,
243247
token: Optional[str] = None,
244248
fields: Optional[List[str]] = None,
@@ -292,7 +296,7 @@ async def get_search(
292296

293297
# Do the request
294298
try:
295-
search_request = PgstacSearch(**base_args)
299+
search_request = self.search_request_model(**base_args)
296300
except ValidationError:
297301
raise HTTPException(status_code=400, detail="Invalid parameters provided")
298302
return await self.post_search(search_request, request=kwargs["request"])

0 commit comments

Comments
 (0)