14
14
from pygeofilter .parsers .cql2_text import parse as parse_cql2_text
15
15
from pypgstac .hydration import hydrate
16
16
from stac_fastapi .api .models import JSONResponse
17
- from stac_fastapi .types .core import AsyncBaseCoreClient
17
+ from stac_fastapi .types .core import AsyncBaseCoreClient , Relations
18
18
from stac_fastapi .types .errors import InvalidQueryParameter , NotFoundError
19
19
from stac_fastapi .types .requests import get_base_url
20
20
from stac_fastapi .types .rfc3339 import DateTimeType
21
21
from stac_fastapi .types .stac import Collection , Collections , Item , ItemCollection
22
- from stac_pydantic .links import Relations
23
22
from stac_pydantic .shared import BBox , MimeTypes
24
23
25
24
from stac_fastapi .pgstac .config import Settings
39
38
class CoreCrudClient (AsyncBaseCoreClient ):
40
39
"""Client for core endpoints defined by stac."""
41
40
42
- async def all_collections (self , request : Request , ** kwargs ) -> Collections :
43
- """Read all collections from the database."""
41
+ async def all_collections ( # noqa: C901
42
+ self ,
43
+ request : Request ,
44
+ # Extensions
45
+ bbox : Optional [BBox ] = None ,
46
+ datetime : Optional [DateTimeType ] = None ,
47
+ limit : Optional [int ] = None ,
48
+ query : Optional [str ] = None ,
49
+ token : Optional [str ] = None ,
50
+ fields : Optional [List [str ]] = None ,
51
+ sortby : Optional [str ] = None ,
52
+ filter : Optional [str ] = None ,
53
+ filter_lang : Optional [str ] = None ,
54
+ ** kwargs ,
55
+ ) -> Collections :
56
+ """Cross catalog search (GET).
57
+
58
+ Called with `GET /collections`.
59
+
60
+ Returns:
61
+ Collections which match the search criteria, returns all
62
+ collections by default.
63
+ """
64
+
65
+ # Parse request parameters
66
+ base_args = {
67
+ "bbox" : bbox ,
68
+ "limit" : limit ,
69
+ "token" : token ,
70
+ "query" : orjson .loads (unquote_plus (query )) if query else query ,
71
+ }
72
+
73
+ clean = clean_search_args (
74
+ base_args = base_args ,
75
+ datetime = datetime ,
76
+ fields = fields ,
77
+ sortby = sortby ,
78
+ filter = filter ,
79
+ filter_lang = filter_lang ,
80
+ )
81
+
82
+ # Do the request
83
+ try :
84
+ search_request = self .post_request_model (** clean )
85
+ except ValidationError as e :
86
+ raise HTTPException (
87
+ status_code = 400 , detail = f"Invalid parameters provided { e } "
88
+ ) from e
89
+
90
+ return await self ._collection_search_base (search_request , request = request )
91
+
92
+ async def _collection_search_base ( # noqa: C901
93
+ self ,
94
+ search_request : PgstacSearch ,
95
+ request : Request ,
96
+ ) -> Collections :
97
+ """Cross catalog search (GET).
98
+
99
+ Called with `GET /search`.
100
+
101
+ Args:
102
+ search_request: search request parameters.
103
+
104
+ Returns:
105
+ All collections which match the search criteria.
106
+ """
44
107
base_url = get_base_url (request )
108
+ search_request_json = search_request .model_dump_json (
109
+ exclude_none = True , by_alias = True
110
+ )
111
+
112
+ try :
113
+ async with request .app .state .get_connection (request , "r" ) as conn :
114
+ q , p = render (
115
+ """
116
+ SELECT * FROM collection_search(:req::text::jsonb);
117
+ """ ,
118
+ req = search_request_json ,
119
+ )
120
+ collections_result : Collections = await conn .fetchval (q , * p )
121
+ except InvalidDatetimeFormatError as e :
122
+ raise InvalidQueryParameter (
123
+ f"Datetime parameter { search_request .datetime } is invalid."
124
+ ) from e
125
+
126
+ next : Optional [str ] = None
127
+ prev : Optional [str ] = None
128
+
129
+ if links := collections_result .get ("links" ):
130
+ next = collections_result ["links" ].pop ("next" )
131
+ prev = collections_result ["links" ].pop ("prev" )
45
132
46
- async with request .app .state .get_connection (request , "r" ) as conn :
47
- collections = await conn .fetchval (
48
- """
49
- SELECT * FROM all_collections();
50
- """
51
- )
52
133
linked_collections : List [Collection ] = []
134
+ collections = collections_result ["collections" ]
53
135
if collections is not None and len (collections ) > 0 :
54
136
for c in collections :
55
137
coll = Collection (** c )
@@ -71,25 +153,16 @@ async def all_collections(self, request: Request, **kwargs) -> Collections:
71
153
72
154
linked_collections .append (coll )
73
155
74
- links = [
75
- {
76
- "rel" : Relations .root .value ,
77
- "type" : MimeTypes .json ,
78
- "href" : base_url ,
79
- },
80
- {
81
- "rel" : Relations .parent .value ,
82
- "type" : MimeTypes .json ,
83
- "href" : base_url ,
84
- },
85
- {
86
- "rel" : Relations .self .value ,
87
- "type" : MimeTypes .json ,
88
- "href" : urljoin (base_url , "collections" ),
89
- },
90
- ]
91
- collection_list = Collections (collections = linked_collections or [], links = links )
92
- return collection_list
156
+ links = await PagingLinks (
157
+ request = request ,
158
+ next = next ,
159
+ prev = prev ,
160
+ ).get_links ()
161
+
162
+ return Collections (
163
+ collections = linked_collections or [],
164
+ links = links ,
165
+ )
93
166
94
167
async def get_collection (
95
168
self , collection_id : str , request : Request , ** kwargs
@@ -383,7 +456,7 @@ async def post_search(
383
456
384
457
return ItemCollection (** item_collection )
385
458
386
- async def get_search ( # noqa: C901
459
+ async def get_search (
387
460
self ,
388
461
request : Request ,
389
462
collections : Optional [List [str ]] = None ,
@@ -418,49 +491,15 @@ async def get_search( # noqa: C901
418
491
"query" : orjson .loads (unquote_plus (query )) if query else query ,
419
492
}
420
493
421
- if filter :
422
- if filter_lang == "cql2-text" :
423
- ast = parse_cql2_text (filter )
424
- base_args ["filter" ] = orjson .loads (to_cql2 (ast ))
425
- base_args ["filter-lang" ] = "cql2-json"
426
-
427
- if datetime :
428
- base_args ["datetime" ] = format_datetime_range (datetime )
429
-
430
- if intersects :
431
- base_args ["intersects" ] = orjson .loads (unquote_plus (intersects ))
432
-
433
- if sortby :
434
- # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
435
- sort_param = []
436
- for sort in sortby :
437
- sortparts = re .match (r"^([+-]?)(.*)$" , sort )
438
- if sortparts :
439
- sort_param .append (
440
- {
441
- "field" : sortparts .group (2 ).strip (),
442
- "direction" : "desc" if sortparts .group (1 ) == "-" else "asc" ,
443
- }
444
- )
445
- base_args ["sortby" ] = sort_param
446
-
447
- if fields :
448
- includes = set ()
449
- excludes = set ()
450
- for field in fields :
451
- if field [0 ] == "-" :
452
- excludes .add (field [1 :])
453
- elif field [0 ] == "+" :
454
- includes .add (field [1 :])
455
- else :
456
- includes .add (field )
457
- base_args ["fields" ] = {"include" : includes , "exclude" : excludes }
458
-
459
- # Remove None values from dict
460
- clean = {}
461
- for k , v in base_args .items ():
462
- if v is not None and v != []:
463
- clean [k ] = v
494
+ clean = clean_search_args (
495
+ base_args = base_args ,
496
+ intersects = intersects ,
497
+ datetime = datetime ,
498
+ fields = fields ,
499
+ sortby = sortby ,
500
+ filter = filter ,
501
+ filter_lang = filter_lang ,
502
+ )
464
503
465
504
# Do the request
466
505
try :
@@ -471,3 +510,60 @@ async def get_search( # noqa: C901
471
510
) from e
472
511
473
512
return await self .post_search (search_request , request = request )
513
+
514
+
515
+ def clean_search_args ( # noqa: C901
516
+ base_args : Dict [str , Any ],
517
+ intersects : Optional [str ] = None ,
518
+ datetime : Optional [DateTimeType ] = None ,
519
+ fields : Optional [List [str ]] = None ,
520
+ sortby : Optional [str ] = None ,
521
+ filter : Optional [str ] = None ,
522
+ filter_lang : Optional [str ] = None ,
523
+ ) -> Dict [str , Any ]:
524
+ """Clean up search arguments to match format expected by pgstac"""
525
+ if filter :
526
+ if filter_lang == "cql2-text" :
527
+ ast = parse_cql2_text (filter )
528
+ base_args ["filter" ] = orjson .loads (to_cql2 (ast ))
529
+ base_args ["filter-lang" ] = "cql2-json"
530
+
531
+ if datetime :
532
+ base_args ["datetime" ] = format_datetime_range (datetime )
533
+
534
+ if intersects :
535
+ base_args ["intersects" ] = orjson .loads (unquote_plus (intersects ))
536
+
537
+ if sortby :
538
+ # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
539
+ sort_param = []
540
+ for sort in sortby :
541
+ sortparts = re .match (r"^([+-]?)(.*)$" , sort )
542
+ if sortparts :
543
+ sort_param .append (
544
+ {
545
+ "field" : sortparts .group (2 ).strip (),
546
+ "direction" : "desc" if sortparts .group (1 ) == "-" else "asc" ,
547
+ }
548
+ )
549
+ base_args ["sortby" ] = sort_param
550
+
551
+ if fields :
552
+ includes = set ()
553
+ excludes = set ()
554
+ for field in fields :
555
+ if field [0 ] == "-" :
556
+ excludes .add (field [1 :])
557
+ elif field [0 ] == "+" :
558
+ includes .add (field [1 :])
559
+ else :
560
+ includes .add (field )
561
+ base_args ["fields" ] = {"include" : includes , "exclude" : excludes }
562
+
563
+ # Remove None values from dict
564
+ clean = {}
565
+ for k , v in base_args .items ():
566
+ if v is not None and v != []:
567
+ clean [k ] = v
568
+
569
+ return clean
0 commit comments