Skip to content

Commit 0ac6b61

Browse files
dependabot[bot] and hrodmn
authored and committed
add collection-search extension
1 parent 4bb5e0f commit 0ac6b61

File tree

8 files changed

+260
-80
lines changed

8 files changed

+260
-80
lines changed

.dockerignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,7 @@ coverage.xml
99
*.log
1010
.git
1111
.envrc
12+
*egg-info
1213

13-
venv
14+
venv
15+
env

.github/workflows/cicd.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ jobs:
4747
runs-on: ubuntu-latest
4848
services:
4949
pgstac:
50-
image: ghcr.io/stac-utils/pgstac:v0.7.10
50+
image: ghcr.io/stac-utils/pgstac:v0.8.6
5151
env:
5252
POSTGRES_USER: username
5353
POSTGRES_PASSWORD: password

CHANGES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
## [Unreleased]
44

5+
- Add collection search extension
6+
57
## [3.0.0] - 2024-08-02
68

79
- Enable filter extension for `GET /items` requests and add `Queryables` links in `/collections` and `/collections/{collection_id}` responses ([#89](https://github.com/stac-utils/stac-fastapi-pgstac/pull/89))

stac_fastapi/pgstac/app.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010
from fastapi.responses import ORJSONResponse
1111
from stac_fastapi.api.app import StacApi
1212
from stac_fastapi.api.models import (
13+
EmptyRequest,
1314
ItemCollectionUri,
1415
create_get_request_model,
1516
create_post_request_model,
1617
create_request_model,
1718
)
1819
from stac_fastapi.extensions.core import (
20+
CollectionSearchExtension,
1921
FieldsExtension,
2022
FilterExtension,
2123
SortExtension,
@@ -47,12 +49,26 @@
4749
"bulk_transactions": BulkTransactionExtension(client=BulkTransactionsClient()),
4850
}
4951

52+
collections_extensions_map = {
53+
"collection_search": CollectionSearchExtension(),
54+
}
55+
5056
if enabled_extensions := os.getenv("ENABLED_EXTENSIONS"):
57+
_enabled_extensions = enabled_extensions.split(",")
5158
extensions = [
52-
extensions_map[extension_name] for extension_name in enabled_extensions.split(",")
59+
extension
60+
for key, extension in extensions_map.items()
61+
if key in _enabled_extensions
62+
]
63+
collection_extensions = [
64+
extension
65+
for key, extension in collections_extensions_map.items()
66+
if key in _enabled_extensions
5367
]
5468
else:
5569
extensions = list(extensions_map.values())
70+
collection_extensions = list(collections_extensions_map.values())
71+
5672

5773
if any(isinstance(ext, TokenPaginationExtension) for ext in extensions):
5874
items_get_request_model = create_request_model(
@@ -64,12 +80,19 @@
6480
else:
6581
items_get_request_model = ItemCollectionUri
6682

67-
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
68-
get_request_model = create_get_request_model(extensions)
83+
if any(isinstance(ext, CollectionSearchExtension) for ext in collection_extensions):
84+
collections_get_request_model = CollectionSearchExtension().GET
85+
else:
86+
collections_get_request_model = EmptyRequest
87+
88+
post_request_model = create_post_request_model(
89+
extensions + collection_extensions, base_model=PgstacSearch
90+
)
91+
get_request_model = create_get_request_model(extensions + collection_extensions)
6992

7093
api = StacApi(
7194
settings=settings,
72-
extensions=extensions,
95+
extensions=extensions + collection_extensions,
7396
client=CoreCrudClient(post_request_model=post_request_model), # type: ignore
7497
response_class=ORJSONResponse,
7598
items_get_request_model=items_get_request_model,

stac_fastapi/pgstac/core.py

Lines changed: 169 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,11 @@
1414
from pygeofilter.parsers.cql2_text import parse as parse_cql2_text
1515
from pypgstac.hydration import hydrate
1616
from stac_fastapi.api.models import JSONResponse
17-
from stac_fastapi.types.core import AsyncBaseCoreClient
17+
from stac_fastapi.types.core import AsyncBaseCoreClient, Relations
1818
from stac_fastapi.types.errors import InvalidQueryParameter, NotFoundError
1919
from stac_fastapi.types.requests import get_base_url
2020
from stac_fastapi.types.rfc3339 import DateTimeType
2121
from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection
22-
from stac_pydantic.links import Relations
2322
from stac_pydantic.shared import BBox, MimeTypes
2423

2524
from stac_fastapi.pgstac.config import Settings
@@ -39,17 +38,100 @@
3938
class CoreCrudClient(AsyncBaseCoreClient):
4039
"""Client for core endpoints defined by stac."""
4140

42-
async def all_collections(self, request: Request, **kwargs) -> Collections:
43-
"""Read all collections from the database."""
41+
async def all_collections(  # noqa: C901
    self,
    request: Request,
    # Extensions
    bbox: Optional[BBox] = None,
    datetime: Optional[DateTimeType] = None,
    limit: Optional[int] = None,
    query: Optional[str] = None,
    token: Optional[str] = None,
    fields: Optional[List[str]] = None,
    sortby: Optional[str] = None,
    filter: Optional[str] = None,
    filter_lang: Optional[str] = None,
    **kwargs,
) -> Collections:
    """Search collections (GET).

    Called with `GET /collections`.

    The individual query parameters are normalised with
    `clean_search_args`, validated by building `self.post_request_model`,
    and the resulting request is handed to `_collection_search_base`.

    Returns:
        Collections which match the search criteria; all collections when
        no criteria are given.

    Raises:
        HTTPException: with status 400 when the parameters fail validation
            against the request model.
    """
    # `query` arrives as a URL-encoded JSON string; decode it up front.
    decoded_query = orjson.loads(unquote_plus(query)) if query else query

    search_args = clean_search_args(
        base_args={
            "bbox": bbox,
            "limit": limit,
            "token": token,
            "query": decoded_query,
        },
        datetime=datetime,
        fields=fields,
        sortby=sortby,
        filter=filter,
        filter_lang=filter_lang,
    )

    # Build the validated request model, reporting bad input as a 400.
    try:
        search_request = self.post_request_model(**search_args)
    except ValidationError as err:
        raise HTTPException(
            status_code=400, detail=f"Invalid parameters provided {err}"
        ) from err

    return await self._collection_search_base(search_request, request=request)
91+
92+
async def _collection_search_base( # noqa: C901
93+
self,
94+
search_request: PgstacSearch,
95+
request: Request,
96+
) -> Collections:
97+
"""Cross catalog search (GET).
98+
99+
Called with `GET /collections`.
100+
101+
Args:
102+
search_request: search request parameters.
103+
104+
Returns:
105+
All collections which match the search criteria.
106+
"""
44107
base_url = get_base_url(request)
108+
search_request_json = search_request.model_dump_json(
109+
exclude_none=True, by_alias=True
110+
)
111+
112+
try:
113+
async with request.app.state.get_connection(request, "r") as conn:
114+
q, p = render(
115+
"""
116+
SELECT * FROM collection_search(:req::text::jsonb);
117+
""",
118+
req=search_request_json,
119+
)
120+
collections_result: Collections = await conn.fetchval(q, *p)
121+
except InvalidDatetimeFormatError as e:
122+
raise InvalidQueryParameter(
123+
f"Datetime parameter {search_request.datetime} is invalid."
124+
) from e
125+
126+
next: Optional[str] = None
127+
prev: Optional[str] = None
128+
129+
if links := collections_result.get("links"):
130+
next = collections_result["links"].pop("next")
131+
prev = collections_result["links"].pop("prev")
45132

46-
async with request.app.state.get_connection(request, "r") as conn:
47-
collections = await conn.fetchval(
48-
"""
49-
SELECT * FROM all_collections();
50-
"""
51-
)
52133
linked_collections: List[Collection] = []
134+
collections = collections_result["collections"]
53135
if collections is not None and len(collections) > 0:
54136
for c in collections:
55137
coll = Collection(**c)
@@ -71,25 +153,16 @@ async def all_collections(self, request: Request, **kwargs) -> Collections:
71153

72154
linked_collections.append(coll)
73155

74-
links = [
75-
{
76-
"rel": Relations.root.value,
77-
"type": MimeTypes.json,
78-
"href": base_url,
79-
},
80-
{
81-
"rel": Relations.parent.value,
82-
"type": MimeTypes.json,
83-
"href": base_url,
84-
},
85-
{
86-
"rel": Relations.self.value,
87-
"type": MimeTypes.json,
88-
"href": urljoin(base_url, "collections"),
89-
},
90-
]
91-
collection_list = Collections(collections=linked_collections or [], links=links)
92-
return collection_list
156+
links = await PagingLinks(
157+
request=request,
158+
next=next,
159+
prev=prev,
160+
).get_links()
161+
162+
return Collections(
163+
collections=linked_collections or [],
164+
links=links,
165+
)
93166

94167
async def get_collection(
95168
self, collection_id: str, request: Request, **kwargs
@@ -383,7 +456,7 @@ async def post_search(
383456

384457
return ItemCollection(**item_collection)
385458

386-
async def get_search( # noqa: C901
459+
async def get_search(
387460
self,
388461
request: Request,
389462
collections: Optional[List[str]] = None,
@@ -418,49 +491,15 @@ async def get_search( # noqa: C901
418491
"query": orjson.loads(unquote_plus(query)) if query else query,
419492
}
420493

421-
if filter:
422-
if filter_lang == "cql2-text":
423-
ast = parse_cql2_text(filter)
424-
base_args["filter"] = orjson.loads(to_cql2(ast))
425-
base_args["filter-lang"] = "cql2-json"
426-
427-
if datetime:
428-
base_args["datetime"] = format_datetime_range(datetime)
429-
430-
if intersects:
431-
base_args["intersects"] = orjson.loads(unquote_plus(intersects))
432-
433-
if sortby:
434-
# https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
435-
sort_param = []
436-
for sort in sortby:
437-
sortparts = re.match(r"^([+-]?)(.*)$", sort)
438-
if sortparts:
439-
sort_param.append(
440-
{
441-
"field": sortparts.group(2).strip(),
442-
"direction": "desc" if sortparts.group(1) == "-" else "asc",
443-
}
444-
)
445-
base_args["sortby"] = sort_param
446-
447-
if fields:
448-
includes = set()
449-
excludes = set()
450-
for field in fields:
451-
if field[0] == "-":
452-
excludes.add(field[1:])
453-
elif field[0] == "+":
454-
includes.add(field[1:])
455-
else:
456-
includes.add(field)
457-
base_args["fields"] = {"include": includes, "exclude": excludes}
458-
459-
# Remove None values from dict
460-
clean = {}
461-
for k, v in base_args.items():
462-
if v is not None and v != []:
463-
clean[k] = v
494+
clean = clean_search_args(
495+
base_args=base_args,
496+
intersects=intersects,
497+
datetime=datetime,
498+
fields=fields,
499+
sortby=sortby,
500+
filter=filter,
501+
filter_lang=filter_lang,
502+
)
464503

465504
# Do the request
466505
try:
@@ -471,3 +510,60 @@ async def get_search( # noqa: C901
471510
) from e
472511

473512
return await self.post_search(search_request, request=request)
513+
514+
515+
def clean_search_args(  # noqa: C901
    base_args: Dict[str, Any],
    intersects: Optional[str] = None,
    datetime: Optional[DateTimeType] = None,
    fields: Optional[List[str]] = None,
    sortby: Optional[List[str]] = None,
    filter: Optional[str] = None,
    filter_lang: Optional[str] = None,
) -> Dict[str, Any]:
    """Clean up search arguments to match the format expected by pgstac.

    Args:
        base_args: arguments that need no conversion (e.g. bbox, limit,
            token, query). The dict is copied, never modified in place.
        intersects: URL-encoded GeoJSON geometry string.
        datetime: datetime or datetime range, rendered with
            `format_datetime_range`.
        fields: field names, optionally prefixed with `+` (include) or
            `-` (exclude). Empty entries are ignored.
        sortby: sort expressions, optionally prefixed with `+` (ascending)
            or `-` (descending). A single comma-separated string is also
            accepted.
        filter: CQL2 filter expression, either cql2-text or a JSON string.
        filter_lang: language of `filter`. `None` is treated as
            `cql2-text`, the default for GET requests.

    Returns:
        A new dict holding the merged arguments with `None` and `[]`
        values removed.
    """
    # Work on a copy so the caller's dict is not changed as a side effect.
    args = dict(base_args)

    if filter:
        if filter_lang is None or filter_lang == "cql2-text":
            ast = parse_cql2_text(filter)
            args["filter"] = orjson.loads(to_cql2(ast))
            args["filter-lang"] = "cql2-json"
        else:
            # JSON-encoded filters (e.g. cql2-json) used to be dropped
            # silently, which returned unfiltered results; pass them on.
            args["filter"] = orjson.loads(filter)
            args["filter-lang"] = filter_lang

    if datetime:
        args["datetime"] = format_datetime_range(datetime)

    if intersects:
        args["intersects"] = orjson.loads(unquote_plus(intersects))

    if sortby:
        # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
        if isinstance(sortby, str):
            # Iterating a bare string would yield single characters.
            sortby = sortby.split(",")

        sort_param = []
        for sort in sortby:
            sortparts = re.match(r"^([+-]?)(.*)$", sort)
            if sortparts:
                sort_param.append(
                    {
                        "field": sortparts.group(2).strip(),
                        "direction": "desc" if sortparts.group(1) == "-" else "asc",
                    }
                )
        args["sortby"] = sort_param

    if fields:
        includes = set()
        excludes = set()
        for field in fields:
            # Guard against empty entries (e.g. a trailing comma), which
            # would otherwise raise IndexError on `field[0]`.
            if not field:
                continue
            if field[0] == "-":
                target, name = excludes, field[1:]
            elif field[0] == "+":
                target, name = includes, field[1:]
            else:
                target, name = includes, field
            # A bare "+" or "-" names no field; skip it.
            if name:
                target.add(name)
        args["fields"] = {"include": includes, "exclude": excludes}

    # Remove None values (and empty lists) from the dict
    return {k: v for k, v in args.items() if v is not None and v != []}

0 commit comments

Comments
 (0)