-
Notifications
You must be signed in to change notification settings - Fork 36
add collection search extension #136
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
930765b
d06480e
6ce1ea7
c597c69
b3065ab
90ef6d4
f560aec
c6b66c5
54962c4
de5c1a4
1dfb484
fd5c48b
bfd94f4
7495544
97adfdc
3bb80f4
a37401c
219eeb1
8d07245
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,5 +9,7 @@ coverage.xml | |
*.log | ||
.git | ||
.envrc | ||
*egg-info | ||
|
||
venv | ||
venv | ||
env |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
"""Item crud client.""" | ||
|
||
import json | ||
import re | ||
from typing import Any, Dict, List, Optional, Set, Union | ||
from urllib.parse import unquote_plus, urljoin | ||
|
@@ -14,12 +15,11 @@ | |
from pygeofilter.parsers.cql2_text import parse as parse_cql2_text | ||
from pypgstac.hydration import hydrate | ||
from stac_fastapi.api.models import JSONResponse | ||
from stac_fastapi.types.core import AsyncBaseCoreClient | ||
from stac_fastapi.types.core import AsyncBaseCoreClient, Relations | ||
from stac_fastapi.types.errors import InvalidQueryParameter, NotFoundError | ||
from stac_fastapi.types.requests import get_base_url | ||
from stac_fastapi.types.rfc3339 import DateTimeType | ||
from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection | ||
from stac_pydantic.links import Relations | ||
from stac_pydantic.shared import BBox, MimeTypes | ||
|
||
from stac_fastapi.pgstac.config import Settings | ||
|
@@ -39,17 +39,66 @@ | |
class CoreCrudClient(AsyncBaseCoreClient): | ||
"""Client for core endpoints defined by stac.""" | ||
|
||
async def all_collections(self, request: Request, **kwargs) -> Collections: | ||
"""Read all collections from the database.""" | ||
async def all_collections( # noqa: C901 | ||
self, | ||
request: Request, | ||
# Extensions | ||
bbox: Optional[BBox] = None, | ||
datetime: Optional[DateTimeType] = None, | ||
limit: Optional[int] = None, | ||
Comment on lines
+46
to
+48
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thoughts from discussion: consider adding There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't that the same as just calling /collections/:id ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The use-case for including an There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it makes sense just to be parallel to the items spec (and yes, "ids" plural which is how it works in items) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you add it, I'd recommend to add a separate conformance class for it so that clients actually know whether it's supported or not. PS: ids is not included in collection search as we just inherit from OGC API - Records, which doesn't have it. It's orthogonal to how ids is not part of OGC API - Features for items. ids is a STAC-specific thing. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unless someone else feels strongly, I think we can leave |
||
query: Optional[str] = None, | ||
token: Optional[str] = None, | ||
fields: Optional[List[str]] = None, | ||
sortby: Optional[str] = None, | ||
filter: Optional[str] = None, | ||
filter_lang: Optional[str] = None, | ||
**kwargs, | ||
) -> Collections: | ||
"""Cross catalog search (GET). | ||
|
||
Called with `GET /collections`. | ||
|
||
Returns: | ||
Collections which match the search criteria, returns all | ||
collections by default. | ||
""" | ||
base_url = get_base_url(request) | ||
|
||
# Parse request parameters | ||
base_args = { | ||
"bbox": bbox, | ||
"limit": limit, | ||
"token": token, | ||
"query": orjson.loads(unquote_plus(query)) if query else query, | ||
} | ||
|
||
clean_args = clean_search_args( | ||
base_args=base_args, | ||
datetime=datetime, | ||
fields=fields, | ||
sortby=sortby, | ||
filter_query=filter, | ||
filter_lang=filter_lang, | ||
) | ||
|
||
async with request.app.state.get_connection(request, "r") as conn: | ||
hrodmn marked this conversation as resolved.
Show resolved
Hide resolved
|
||
collections = await conn.fetchval( | ||
""" | ||
SELECT * FROM all_collections(); | ||
q, p = render( | ||
""" | ||
SELECT * FROM collection_search(:req::text::jsonb); | ||
""", | ||
req=json.dumps(clean_args), | ||
hrodmn marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
collections_result: Collections = await conn.fetchval(q, *p) | ||
|
||
next: Optional[str] = None | ||
prev: Optional[str] = None | ||
|
||
if links := collections_result.get("links"): | ||
next = collections_result["links"].pop("next") | ||
prev = collections_result["links"].pop("prev") | ||
|
||
linked_collections: List[Collection] = [] | ||
collections = collections_result["collections"] | ||
if collections is not None and len(collections) > 0: | ||
for c in collections: | ||
coll = Collection(**c) | ||
|
@@ -71,25 +120,16 @@ async def all_collections(self, request: Request, **kwargs) -> Collections: | |
|
||
linked_collections.append(coll) | ||
|
||
links = [ | ||
{ | ||
"rel": Relations.root.value, | ||
"type": MimeTypes.json, | ||
"href": base_url, | ||
}, | ||
{ | ||
"rel": Relations.parent.value, | ||
"type": MimeTypes.json, | ||
"href": base_url, | ||
}, | ||
{ | ||
"rel": Relations.self.value, | ||
"type": MimeTypes.json, | ||
"href": urljoin(base_url, "collections"), | ||
}, | ||
] | ||
collection_list = Collections(collections=linked_collections or [], links=links) | ||
return collection_list | ||
links = await PagingLinks( | ||
request=request, | ||
next=next, | ||
prev=prev, | ||
).get_links() | ||
|
||
return Collections( | ||
collections=linked_collections or [], | ||
links=links, | ||
) | ||
|
||
async def get_collection( | ||
self, collection_id: str, request: Request, **kwargs | ||
|
@@ -386,7 +426,7 @@ async def post_search( | |
|
||
return ItemCollection(**item_collection) | ||
|
||
async def get_search( # noqa: C901 | ||
async def get_search( | ||
self, | ||
request: Request, | ||
collections: Optional[List[str]] = None, | ||
|
@@ -421,51 +461,15 @@ async def get_search( # noqa: C901 | |
"query": orjson.loads(unquote_plus(query)) if query else query, | ||
} | ||
|
||
if filter: | ||
if filter_lang == "cql2-text": | ||
filter = to_cql2(parse_cql2_text(filter)) | ||
filter_lang = "cql2-json" | ||
|
||
base_args["filter"] = orjson.loads(filter) | ||
base_args["filter-lang"] = filter_lang | ||
|
||
if datetime: | ||
base_args["datetime"] = format_datetime_range(datetime) | ||
|
||
if intersects: | ||
base_args["intersects"] = orjson.loads(unquote_plus(intersects)) | ||
|
||
if sortby: | ||
# https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form | ||
sort_param = [] | ||
for sort in sortby: | ||
sortparts = re.match(r"^([+-]?)(.*)$", sort) | ||
if sortparts: | ||
sort_param.append( | ||
{ | ||
"field": sortparts.group(2).strip(), | ||
"direction": "desc" if sortparts.group(1) == "-" else "asc", | ||
} | ||
) | ||
base_args["sortby"] = sort_param | ||
|
||
if fields: | ||
includes = set() | ||
excludes = set() | ||
for field in fields: | ||
if field[0] == "-": | ||
excludes.add(field[1:]) | ||
elif field[0] == "+": | ||
includes.add(field[1:]) | ||
else: | ||
includes.add(field) | ||
base_args["fields"] = {"include": includes, "exclude": excludes} | ||
|
||
# Remove None values from dict | ||
clean = {} | ||
for k, v in base_args.items(): | ||
if v is not None and v != []: | ||
clean[k] = v | ||
clean = clean_search_args( | ||
base_args=base_args, | ||
intersects=intersects, | ||
datetime=datetime, | ||
fields=fields, | ||
sortby=sortby, | ||
filter_query=filter, | ||
filter_lang=filter_lang, | ||
) | ||
|
||
# Do the request | ||
try: | ||
|
@@ -476,3 +480,62 @@ async def get_search( # noqa: C901 | |
) from e | ||
|
||
return await self.post_search(search_request, request=request) | ||
|
||
|
||
def clean_search_args( # noqa: C901 | ||
base_args: Dict[str, Any], | ||
intersects: Optional[str] = None, | ||
datetime: Optional[DateTimeType] = None, | ||
fields: Optional[List[str]] = None, | ||
sortby: Optional[str] = None, | ||
filter_query: Optional[str] = None, | ||
filter_lang: Optional[str] = None, | ||
) -> Dict[str, Any]: | ||
"""Clean up search arguments to match format expected by pgstac""" | ||
if filter_query: | ||
if filter_lang == "cql2-text": | ||
filter_query = to_cql2(parse_cql2_text(filter_query)) | ||
filter_lang = "cql2-json" | ||
|
||
base_args["filter"] = orjson.loads(filter_query) | ||
base_args["filter_lang"] = filter_lang | ||
|
||
if datetime: | ||
base_args["datetime"] = format_datetime_range(datetime) | ||
|
||
if intersects: | ||
base_args["intersects"] = orjson.loads(unquote_plus(intersects)) | ||
|
||
if sortby: | ||
# https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form | ||
sort_param = [] | ||
for sort in sortby: | ||
sortparts = re.match(r"^([+-]?)(.*)$", sort) | ||
if sortparts: | ||
sort_param.append( | ||
{ | ||
"field": sortparts.group(2).strip(), | ||
"direction": "desc" if sortparts.group(1) == "-" else "asc", | ||
} | ||
) | ||
base_args["sortby"] = sort_param | ||
|
||
if fields: | ||
includes = set() | ||
excludes = set() | ||
for field in fields: | ||
if field[0] == "-": | ||
excludes.add(field[1:]) | ||
elif field[0] == "+": | ||
includes.add(field[1:]) | ||
else: | ||
includes.add(field) | ||
base_args["fields"] = {"include": includes, "exclude": excludes} | ||
|
||
# Remove None values from dict | ||
clean = {} | ||
for k, v in base_args.items(): | ||
if v is not None and v != []: | ||
clean[k] = v | ||
|
||
return clean |
Uh oh!
There was an error while loading. Please reload this page.