Skip to content

Commit 1cd6833

Browse files
committed
add collection search extension
1 parent 1f36485 commit 1cd6833

File tree

8 files changed

+273
-61
lines changed

8 files changed

+273
-61
lines changed

.dockerignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,7 @@ coverage.xml
99
*.log
1010
.git
1111
.envrc
12+
*egg-info
1213

13-
venv
14+
venv
15+
env

CHANGES.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22

33
## [Unreleased]
44

56
- Enable filter extension for `GET /items` requests and add `Queryables` links in `/collections` and `/collections/{collection_id}` responses ([#89](https://github.com/stac-utils/stac-fastapi-pgstac/pull/89))
8+
- Add collection search extension
610
711
## [3.0.0a4] - 2024-07-10
812

setup.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
"orjson",
1111
"pydantic",
1212
"stac_pydantic==3.1.*",
13-
"stac-fastapi.api~=3.0.0b2",
14-
"stac-fastapi.extensions~=3.0.0b2",
15-
"stac-fastapi.types~=3.0.0b2",
13+
"stac-fastapi.api~=3.0.0b3",
14+
"stac-fastapi.extensions~=3.0.0b3",
15+
"stac-fastapi.types~=3.0.0b3",
1616
"asyncpg",
1717
"buildpg",
1818
"brotli_asgi",

stac_fastapi/pgstac/app.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010
from fastapi.responses import ORJSONResponse
1111
from stac_fastapi.api.app import StacApi
1212
from stac_fastapi.api.models import (
13+
EmptyRequest,
1314
ItemCollectionUri,
1415
create_get_request_model,
1516
create_post_request_model,
1617
create_request_model,
1718
)
1819
from stac_fastapi.extensions.core import (
20+
CollectionSearchExtension,
1921
FieldsExtension,
2022
FilterExtension,
2123
SortExtension,
@@ -47,12 +49,26 @@
4749
"bulk_transactions": BulkTransactionExtension(client=BulkTransactionsClient()),
4850
}
4951

52+
collections_extensions_map = {
53+
"collection_search": CollectionSearchExtension(),
54+
}
55+
5056
if enabled_extensions := os.getenv("ENABLED_EXTENSIONS"):
57+
_enabled_extensions = enabled_extensions.split(",")
5158
extensions = [
52-
extensions_map[extension_name] for extension_name in enabled_extensions.split(",")
59+
extension
60+
for key, extension in extensions_map.items()
61+
if key in _enabled_extensions
62+
]
63+
collection_extensions = [
64+
extension
65+
for key, extension in collections_extensions_map.items()
66+
if key in _enabled_extensions
5367
]
5468
else:
5569
extensions = list(extensions_map.values())
70+
collection_extensions = list(collections_extensions_map.values())
71+
5672

5773
if any(isinstance(ext, TokenPaginationExtension) for ext in extensions):
5874
items_get_request_model = create_request_model(
@@ -64,17 +80,23 @@
6480
else:
6581
items_get_request_model = ItemCollectionUri
6682

83+
if any(isinstance(ext, CollectionSearchExtension) for ext in collection_extensions):
84+
collections_get_request_model = CollectionSearchExtension().GET
85+
else:
86+
collections_get_request_model = EmptyRequest
87+
6788
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
6889
get_request_model = create_get_request_model(extensions)
6990

7091
api = StacApi(
7192
settings=settings,
72-
extensions=extensions,
93+
extensions=extensions + collection_extensions,
7394
client=CoreCrudClient(post_request_model=post_request_model), # type: ignore
7495
response_class=ORJSONResponse,
7596
items_get_request_model=items_get_request_model,
7697
search_get_request_model=get_request_model,
7798
search_post_request_model=post_request_model,
99+
collections_get_request_model=collections_get_request_model,
78100
)
79101
app = api.app
80102

stac_fastapi/pgstac/core.py

Lines changed: 181 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,113 @@
3939
class CoreCrudClient(AsyncBaseCoreClient):
4040
"""Client for core endpoints defined by stac."""
4141

42-
async def all_collections(self, request: Request, **kwargs) -> Collections:
43-
"""Read all collections from the database."""
42+
async def all_collections(  # noqa: C901
    self,
    request: Request,
    bbox: Optional[BBox] = None,
    datetime: Optional[DateTimeType] = None,
    limit: Optional[int] = None,
    # Extensions
    query: Optional[str] = None,
    token: Optional[str] = None,
    fields: Optional[List[str]] = None,
    sortby: Optional[str] = None,
    filter: Optional[str] = None,
    filter_lang: Optional[str] = None,
    **kwargs,
) -> Collections:
    """Search collections (GET).

    Called with `GET /collections`.

    The individual query parameters are normalised with `clean_search_args`,
    validated through `self.post_request_model` and then handed to
    `self._collection_search_base`.

    Args:
        request: incoming request; also used to recover `filter-lang` from
            the raw query string when it was not bound to `filter_lang`.
        bbox: optional bounding box filter.
        datetime: optional datetime (or datetime range) filter.
        limit: optional maximum number of collections to return.
        query: optional URL-encoded JSON document for the query extension.
        token: optional pagination token.
        fields: optional list of fields to include/exclude.
        sortby: optional sort specification.
        filter: optional CQL2 filter expression.
        filter_lang: optional language of `filter`.
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        Collections which match the search criteria, returns all
        collections by default.

    Raises:
        HTTPException: with status 400 when `query` is not valid JSON or
            when the cleaned parameters fail model validation.
    """
    # Kludgy fix because using factory does not allow alias for filter-lang
    if filter_lang is None:
        match = re.search(
            r"filter-lang=([a-z0-9-]+)", str(request.query_params), re.IGNORECASE
        )
        if match:
            filter_lang = match.group(1)

    # A malformed `query` document is a client error: report it as 400 rather
    # than letting the decode error surface as an internal server error.
    try:
        parsed_query = orjson.loads(unquote_plus(query)) if query else query
    except orjson.JSONDecodeError as e:
        raise HTTPException(
            status_code=400, detail=f"Invalid parameters provided {e}"
        ) from e

    # Parse request parameters
    base_args = {
        "bbox": bbox,
        "limit": limit,
        "token": token,
        "query": parsed_query,
    }

    clean = clean_search_args(
        base_args=base_args,
        datetime=datetime,
        fields=fields,
        sortby=sortby,
        filter=filter,
        filter_lang=filter_lang,
    )

    # Do the request
    try:
        search_request = self.post_request_model(**clean)
    except ValidationError as e:
        raise HTTPException(
            status_code=400, detail=f"Invalid parameters provided {e}"
        ) from e

    return await self._collection_search_base(search_request, request=request)
99+
100+
async def _collection_search_base( # noqa: C901
101+
self,
102+
search_request: PgstacSearch,
103+
request: Request,
104+
) -> Collections:
105+
"""Collection search shared implementation.
106+

Called by `all_collections` to serve `GET /collections`.
108+
109+
Args:
110+
search_request: search request parameters.
111+
112+
Returns:
113+
All collections which match the search criteria.
114+
"""
115+
44116
base_url = get_base_url(request)
45117

46-
async with request.app.state.get_connection(request, "r") as conn:
47-
collections = await conn.fetchval(
48-
"""
49-
SELECT * FROM all_collections();
50-
"""
51-
)
118+
settings: Settings = request.app.state.settings
119+
120+
if search_request.datetime:
121+
search_request.datetime = format_datetime_range(search_request.datetime)
122+
123+
search_request.conf = search_request.conf or {}
124+
search_request.conf["nohydrate"] = settings.use_api_hydrate
125+
126+
search_request_json = search_request.model_dump_json(
127+
exclude_none=True, by_alias=True
128+
)
129+
130+
try:
131+
async with request.app.state.get_connection(request, "r") as conn:
132+
q, p = render(
133+
"""
134+
SELECT * FROM collection_search(:req::text::jsonb);
135+
""",
136+
req=search_request_json,
137+
)
138+
collections_result: Collections = await conn.fetchval(q, *p)
139+
except InvalidDatetimeFormatError as e:
140+
raise InvalidQueryParameter(
141+
f"Datetime parameter {search_request.datetime} is invalid."
142+
) from e
143+
144+
# next: Optional[str] = collections_result["links"].pop("next")
145+
# prev: Optional[str] = collections_result["links"].pop("prev")
146+
52147
linked_collections: List[Collection] = []
148+
collections = collections_result["collections"]
53149
if collections is not None and len(collections) > 0:
54150
for c in collections:
55151
coll = Collection(**c)
@@ -71,6 +167,12 @@ async def all_collections(self, request: Request, **kwargs) -> Collections:
71167

72168
linked_collections.append(coll)
73169

170+
# paging_links = await PagingLinks(
171+
# request=request,
172+
# next=next,
173+
# prev=prev,
174+
# ).get_links()
175+
74176
links = [
75177
{
76178
"rel": Relations.root.value,
@@ -88,8 +190,10 @@ async def all_collections(self, request: Request, **kwargs) -> Collections:
88190
"href": urljoin(base_url, "collections"),
89191
},
90192
]
91-
collection_list = Collections(collections=linked_collections or [], links=links)
92-
return collection_list
193+
return Collections(
194+
collections=linked_collections or [],
195+
links=links, # + paging_links
196+
)
93197

94198
async def get_collection(
95199
self, collection_id: str, request: Request, **kwargs
@@ -383,7 +487,7 @@ async def post_search(
383487

384488
return ItemCollection(**item_collection)
385489

386-
async def get_search( # noqa: C901
490+
async def get_search(
387491
self,
388492
request: Request,
389493
collections: Optional[List[str]] = None,
@@ -418,49 +522,15 @@ async def get_search( # noqa: C901
418522
"query": orjson.loads(unquote_plus(query)) if query else query,
419523
}
420524

421-
if filter:
422-
if filter_lang == "cql2-text":
423-
ast = parse_cql2_text(filter)
424-
base_args["filter"] = orjson.loads(to_cql2(ast))
425-
base_args["filter-lang"] = "cql2-json"
426-
427-
if datetime:
428-
base_args["datetime"] = format_datetime_range(datetime)
429-
430-
if intersects:
431-
base_args["intersects"] = orjson.loads(unquote_plus(intersects))
432-
433-
if sortby:
434-
# https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
435-
sort_param = []
436-
for sort in sortby:
437-
sortparts = re.match(r"^([+-]?)(.*)$", sort)
438-
if sortparts:
439-
sort_param.append(
440-
{
441-
"field": sortparts.group(2).strip(),
442-
"direction": "desc" if sortparts.group(1) == "-" else "asc",
443-
}
444-
)
445-
base_args["sortby"] = sort_param
446-
447-
if fields:
448-
includes = set()
449-
excludes = set()
450-
for field in fields:
451-
if field[0] == "-":
452-
excludes.add(field[1:])
453-
elif field[0] == "+":
454-
includes.add(field[1:])
455-
else:
456-
includes.add(field)
457-
base_args["fields"] = {"include": includes, "exclude": excludes}
458-
459-
# Remove None values from dict
460-
clean = {}
461-
for k, v in base_args.items():
462-
if v is not None and v != []:
463-
clean[k] = v
525+
clean = clean_search_args(
526+
base_args=base_args,
527+
intersects=intersects,
528+
datetime=datetime,
529+
fields=fields,
530+
sortby=sortby,
531+
filter=filter,
532+
filter_lang=filter_lang,
533+
)
464534

465535
# Do the request
466536
try:
@@ -471,3 +541,60 @@ async def get_search( # noqa: C901
471541
) from e
472542

473543
return await self.post_search(search_request, request=request)
544+
545+
546+
def clean_search_args(  # noqa: C901
    base_args: dict[str, Any],
    intersects: Optional[str] = None,
    datetime: Optional["DateTimeType"] = None,
    fields: Optional[List[str]] = None,
    sortby: Optional[List[str]] = None,
    filter: Optional[str] = None,
    filter_lang: Optional[str] = None,
) -> dict[str, Any]:
    """Clean up search arguments to match format expected by pgstac.

    Args:
        base_args: arguments that need no conversion (e.g. bbox, limit,
            token, query). The mapping is copied, never modified.
        intersects: URL-encoded JSON geometry, decoded into a JSON object.
        datetime: datetime or datetime range, passed through
            `format_datetime_range`.
        fields: field names; a leading `-` excludes the field, a leading
            `+` (or no prefix) includes it. Empty names are skipped.
        sortby: sort keys; a leading `-` sorts descending, anything else
            ascending. A single comma-separated string is also accepted.
        filter: CQL2 filter expression, either CQL2 text or a JSON string.
        filter_lang: language of `filter`. `None` is treated as
            `cql2-text` (assumed to be the GET default of the filter
            extension - TODO confirm against the request model).

    Returns:
        A new dict holding `base_args` plus the converted arguments, with
        entries whose value is `None` or `[]` removed.
    """
    # Work on a copy so the caller's dict is not changed as a side effect.
    args = dict(base_args)

    if filter:
        if filter_lang is None or filter_lang == "cql2-text":
            filter = to_cql2(parse_cql2_text(filter))
            filter_lang = "cql2-json"
        # Filters in any other language are forwarded as decoded JSON
        # instead of being silently dropped.
        args["filter"] = orjson.loads(filter)
        args["filter-lang"] = filter_lang

    if datetime:
        args["datetime"] = format_datetime_range(datetime)

    if intersects:
        args["intersects"] = orjson.loads(unquote_plus(intersects))

    if sortby:
        # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
        if isinstance(sortby, str):
            # Iterating a bare string would yield single characters.
            sortby = sortby.split(",")
        sort_param = []
        for sort in sortby:
            sortparts = re.match(r"^([+-]?)(.*)$", sort)
            if sortparts:
                sort_param.append(
                    {
                        "field": sortparts.group(2).strip(),
                        "direction": "desc" if sortparts.group(1) == "-" else "asc",
                    }
                )
        args["sortby"] = sort_param

    if fields:
        includes = set()
        excludes = set()
        for field in fields:
            target = excludes if field.startswith("-") else includes
            name = field[1:] if field.startswith(("+", "-")) else field
            if name:
                target.add(name)
        args["fields"] = {"include": includes, "exclude": excludes}

    # Remove None values (and empty lists) from dict
    return {k: v for k, v in args.items() if v is not None and v != []}

0 commit comments

Comments
 (0)