Skip to content

Commit 9b2050a

Browse files
committed
add collection search extension
1 parent a81e4d7 commit 9b2050a

File tree

8 files changed

+271
-61
lines changed

8 files changed

+271
-61
lines changed

.dockerignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,7 @@ coverage.xml
99
*.log
1010
.git
1111
.envrc
12+
*egg-info
1213

13-
venv
14+
venv
15+
env

CHANGES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
## [Unreleased]
44

5+
- Add collection search extension
6+
57
## [3.0.0a4] - 2024-07-10
68

79
- Update stac-fastapi libraries to `~=3.0.0b2`

setup.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
"orjson",
1111
"pydantic",
1212
"stac_pydantic==3.1.*",
13-
"stac-fastapi.api~=3.0.0b2",
14-
"stac-fastapi.extensions~=3.0.0b2",
15-
"stac-fastapi.types~=3.0.0b2",
13+
"stac-fastapi.api~=3.0.0b3",
14+
"stac-fastapi.extensions~=3.0.0b3",
15+
"stac-fastapi.types~=3.0.0b3",
1616
"asyncpg",
1717
"buildpg",
1818
"brotli_asgi",

stac_fastapi/pgstac/app.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010
from fastapi.responses import ORJSONResponse
1111
from stac_fastapi.api.app import StacApi
1212
from stac_fastapi.api.models import (
13+
EmptyRequest,
1314
ItemCollectionUri,
1415
create_get_request_model,
1516
create_post_request_model,
1617
create_request_model,
1718
)
1819
from stac_fastapi.extensions.core import (
20+
CollectionSearchExtension,
1921
FieldsExtension,
2022
FilterExtension,
2123
SortExtension,
@@ -47,12 +49,26 @@
4749
"bulk_transactions": BulkTransactionExtension(client=BulkTransactionsClient()),
4850
}
4951

52+
collections_extensions_map = {
53+
"collection_search": CollectionSearchExtension(),
54+
}
55+
5056
if enabled_extensions := os.getenv("ENABLED_EXTENSIONS"):
57+
_enabled_extensions = enabled_extensions.split(",")
5158
extensions = [
52-
extensions_map[extension_name] for extension_name in enabled_extensions.split(",")
59+
extension
60+
for key, extension in extensions_map.items()
61+
if key in _enabled_extensions
62+
]
63+
collection_extensions = [
64+
extension
65+
for key, extension in collections_extensions_map.items()
66+
if key in _enabled_extensions
5367
]
5468
else:
5569
extensions = list(extensions_map.values())
70+
collection_extensions = list(collections_extensions_map.values())
71+
5672

5773
if any(isinstance(ext, TokenPaginationExtension) for ext in extensions):
5874
items_get_request_model = create_request_model(
@@ -64,17 +80,23 @@
6480
else:
6581
items_get_request_model = ItemCollectionUri
6682

83+
if any(isinstance(ext, CollectionSearchExtension) for ext in collection_extensions):
84+
collections_get_request_model = CollectionSearchExtension().GET
85+
else:
86+
collections_get_request_model = EmptyRequest
87+
6788
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
6889
get_request_model = create_get_request_model(extensions)
6990

7091
api = StacApi(
7192
settings=settings,
72-
extensions=extensions,
93+
extensions=extensions + collection_extensions,
7394
client=CoreCrudClient(post_request_model=post_request_model), # type: ignore
7495
response_class=ORJSONResponse,
7596
items_get_request_model=items_get_request_model,
7697
search_get_request_model=get_request_model,
7798
search_post_request_model=post_request_model,
99+
collections_get_request_model=collections_get_request_model,
78100
)
79101
app = api.app
80102

stac_fastapi/pgstac/core.py

Lines changed: 181 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,113 @@
3939
class CoreCrudClient(AsyncBaseCoreClient):
4040
"""Client for core endpoints defined by stac."""
4141

42-
async def all_collections(self, request: Request, **kwargs) -> Collections:
43-
"""Read all collections from the database."""
42+
async def all_collections( # noqa: C901
43+
self,
44+
request: Request,
45+
bbox: Optional[BBox] = None,
46+
datetime: Optional[DateTimeType] = None,
47+
limit: Optional[int] = None,
48+
# Extensions
49+
query: Optional[str] = None,
50+
token: Optional[str] = None,
51+
fields: Optional[List[str]] = None,
52+
sortby: Optional[str] = None,
53+
filter: Optional[str] = None,
54+
filter_lang: Optional[str] = None,
55+
**kwargs,
56+
) -> Collections:
57+
"""Cross catalog search (GET).
58+
59+
Called with `GET /collections`.
60+
61+
Returns:
62+
Collections which match the search criteria, returns all
63+
collections by default.
64+
"""
65+
query_params = str(request.query_params)
66+
67+
# Kludgy fix because using factory does not allow alias for filter-lang
68+
if filter_lang is None:
69+
match = re.search(r"filter-lang=([a-z0-9-]+)", query_params, re.IGNORECASE)
70+
if match:
71+
filter_lang = match.group(1)
72+
73+
# Parse request parameters
74+
base_args = {
75+
"bbox": bbox,
76+
"limit": limit,
77+
"token": token,
78+
"query": orjson.loads(unquote_plus(query)) if query else query,
79+
}
80+
81+
clean = clean_search_args(
82+
base_args=base_args,
83+
datetime=datetime,
84+
fields=fields,
85+
sortby=sortby,
86+
filter=filter,
87+
filter_lang=filter_lang,
88+
)
89+
90+
# Do the request
91+
try:
92+
search_request = self.post_request_model(**clean)
93+
except ValidationError as e:
94+
raise HTTPException(
95+
status_code=400, detail=f"Invalid parameters provided {e}"
96+
) from e
97+
98+
return await self._collection_search_base(search_request, request=request)
99+
100+
async def _collection_search_base( # noqa: C901
101+
self,
102+
search_request: PgstacSearch,
103+
request: Request,
104+
) -> Collections:
105+
"""Cross catalog search (POST).
106+
107+
Called with `POST /search`.
108+
109+
Args:
110+
search_request: search request parameters.
111+
112+
Returns:
113+
All collections which match the search criteria.
114+
"""
115+
44116
base_url = get_base_url(request)
45117

46-
async with request.app.state.get_connection(request, "r") as conn:
47-
collections = await conn.fetchval(
48-
"""
49-
SELECT * FROM all_collections();
50-
"""
51-
)
118+
settings: Settings = request.app.state.settings
119+
120+
if search_request.datetime:
121+
search_request.datetime = format_datetime_range(search_request.datetime)
122+
123+
search_request.conf = search_request.conf or {}
124+
search_request.conf["nohydrate"] = settings.use_api_hydrate
125+
126+
search_request_json = search_request.model_dump_json(
127+
exclude_none=True, by_alias=True
128+
)
129+
130+
try:
131+
async with request.app.state.get_connection(request, "r") as conn:
132+
q, p = render(
133+
"""
134+
SELECT * FROM collection_search(:req::text::jsonb);
135+
""",
136+
req=search_request_json,
137+
)
138+
collections_result: Collections = await conn.fetchval(q, *p)
139+
except InvalidDatetimeFormatError as e:
140+
raise InvalidQueryParameter(
141+
f"Datetime parameter {search_request.datetime} is invalid."
142+
) from e
143+
144+
# next: Optional[str] = collections_result["links"].pop("next")
145+
# prev: Optional[str] = collections_result["links"].pop("prev")
146+
52147
linked_collections: List[Collection] = []
148+
collections = collections_result["collections"]
53149
if collections is not None and len(collections) > 0:
54150
for c in collections:
55151
coll = Collection(**c)
@@ -59,6 +155,12 @@ async def all_collections(self, request: Request, **kwargs) -> Collections:
59155

60156
linked_collections.append(coll)
61157

158+
# paging_links = await PagingLinks(
159+
# request=request,
160+
# next=next,
161+
# prev=prev,
162+
# ).get_links()
163+
62164
links = [
63165
{
64166
"rel": Relations.root.value,
@@ -76,8 +178,10 @@ async def all_collections(self, request: Request, **kwargs) -> Collections:
76178
"href": urljoin(base_url, "collections"),
77179
},
78180
]
79-
collection_list = Collections(collections=linked_collections or [], links=links)
80-
return collection_list
181+
return Collections(
182+
collections=linked_collections or [],
183+
links=links, # + paging_links
184+
)
81185

82186
async def get_collection(
83187
self, collection_id: str, request: Request, **kwargs
@@ -352,7 +456,7 @@ async def post_search(
352456

353457
return ItemCollection(**item_collection)
354458

355-
async def get_search( # noqa: C901
459+
async def get_search(
356460
self,
357461
request: Request,
358462
collections: Optional[List[str]] = None,
@@ -395,49 +499,15 @@ async def get_search( # noqa: C901
395499
"query": orjson.loads(unquote_plus(query)) if query else query,
396500
}
397501

398-
if filter:
399-
if filter_lang == "cql2-text":
400-
ast = parse_cql2_text(filter)
401-
base_args["filter"] = orjson.loads(to_cql2(ast))
402-
base_args["filter-lang"] = "cql2-json"
403-
404-
if datetime:
405-
base_args["datetime"] = format_datetime_range(datetime)
406-
407-
if intersects:
408-
base_args["intersects"] = orjson.loads(unquote_plus(intersects))
409-
410-
if sortby:
411-
# https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
412-
sort_param = []
413-
for sort in sortby:
414-
sortparts = re.match(r"^([+-]?)(.*)$", sort)
415-
if sortparts:
416-
sort_param.append(
417-
{
418-
"field": sortparts.group(2).strip(),
419-
"direction": "desc" if sortparts.group(1) == "-" else "asc",
420-
}
421-
)
422-
base_args["sortby"] = sort_param
423-
424-
if fields:
425-
includes = set()
426-
excludes = set()
427-
for field in fields:
428-
if field[0] == "-":
429-
excludes.add(field[1:])
430-
elif field[0] == "+":
431-
includes.add(field[1:])
432-
else:
433-
includes.add(field)
434-
base_args["fields"] = {"include": includes, "exclude": excludes}
435-
436-
# Remove None values from dict
437-
clean = {}
438-
for k, v in base_args.items():
439-
if v is not None and v != []:
440-
clean[k] = v
502+
clean = clean_search_args(
503+
base_args=base_args,
504+
intersects=intersects,
505+
datetime=datetime,
506+
fields=fields,
507+
sortby=sortby,
508+
filter=filter,
509+
filter_lang=filter_lang,
510+
)
441511

442512
# Do the request
443513
try:
@@ -448,3 +518,60 @@ async def get_search( # noqa: C901
448518
) from e
449519

450520
return await self.post_search(search_request, request=request)
521+
522+
523+
def clean_search_args( # noqa: C901
524+
base_args: dict[str, Any],
525+
intersects: Optional[str] = None,
526+
datetime: Optional[DateTimeType] = None,
527+
fields: Optional[List[str]] = None,
528+
sortby: Optional[str] = None,
529+
filter: Optional[str] = None,
530+
filter_lang: Optional[str] = None,
531+
) -> dict[str, Any]:
532+
"""Clean up search arguments to match format expected by pgstac"""
533+
if filter:
534+
if filter_lang == "cql2-text":
535+
ast = parse_cql2_text(filter)
536+
base_args["filter"] = orjson.loads(to_cql2(ast))
537+
base_args["filter-lang"] = "cql2-json"
538+
539+
if datetime:
540+
base_args["datetime"] = format_datetime_range(datetime)
541+
542+
if intersects:
543+
base_args["intersects"] = orjson.loads(unquote_plus(intersects))
544+
545+
if sortby:
546+
# https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
547+
sort_param = []
548+
for sort in sortby:
549+
sortparts = re.match(r"^([+-]?)(.*)$", sort)
550+
if sortparts:
551+
sort_param.append(
552+
{
553+
"field": sortparts.group(2).strip(),
554+
"direction": "desc" if sortparts.group(1) == "-" else "asc",
555+
}
556+
)
557+
base_args["sortby"] = sort_param
558+
559+
if fields:
560+
includes = set()
561+
excludes = set()
562+
for field in fields:
563+
if field[0] == "-":
564+
excludes.add(field[1:])
565+
elif field[0] == "+":
566+
includes.add(field[1:])
567+
else:
568+
includes.add(field)
569+
base_args["fields"] = {"include": includes, "exclude": excludes}
570+
571+
# Remove None values from dict
572+
clean = {}
573+
for k, v in base_args.items():
574+
if v is not None and v != []:
575+
clean[k] = v
576+
577+
return clean

0 commit comments

Comments
 (0)