39
39
class CoreCrudClient (AsyncBaseCoreClient ):
40
40
"""Client for core endpoints defined by stac."""
41
41
42
- async def all_collections (self , request : Request , ** kwargs ) -> Collections :
43
- """Read all collections from the database."""
42
+ async def all_collections ( # noqa: C901
43
+ self ,
44
+ request : Request ,
45
+ bbox : Optional [BBox ] = None ,
46
+ datetime : Optional [DateTimeType ] = None ,
47
+ limit : Optional [int ] = None ,
48
+ # Extensions
49
+ query : Optional [str ] = None ,
50
+ token : Optional [str ] = None ,
51
+ fields : Optional [List [str ]] = None ,
52
+ sortby : Optional [str ] = None ,
53
+ filter : Optional [str ] = None ,
54
+ filter_lang : Optional [str ] = None ,
55
+ ** kwargs ,
56
+ ) -> Collections :
57
+ """Cross catalog search (GET).
58
+
59
+ Called with `GET /collections`.
60
+
61
+ Returns:
62
+ Collections which match the search criteria, returns all
63
+ collections by default.
64
+ """
65
+ query_params = str (request .query_params )
66
+
67
+ # Kludgy fix because using factory does not allow alias for filter-lang
68
+ if filter_lang is None :
69
+ match = re .search (r"filter-lang=([a-z0-9-]+)" , query_params , re .IGNORECASE )
70
+ if match :
71
+ filter_lang = match .group (1 )
72
+
73
+ # Parse request parameters
74
+ base_args = {
75
+ "bbox" : bbox ,
76
+ "limit" : limit ,
77
+ "token" : token ,
78
+ "query" : orjson .loads (unquote_plus (query )) if query else query ,
79
+ }
80
+
81
+ clean = clean_search_args (
82
+ base_args = base_args ,
83
+ datetime = datetime ,
84
+ fields = fields ,
85
+ sortby = sortby ,
86
+ filter = filter ,
87
+ filter_lang = filter_lang ,
88
+ )
89
+
90
+ # Do the request
91
+ try :
92
+ search_request = self .post_request_model (** clean )
93
+ except ValidationError as e :
94
+ raise HTTPException (
95
+ status_code = 400 , detail = f"Invalid parameters provided { e } "
96
+ ) from e
97
+
98
+ return await self ._collection_search_base (search_request , request = request )
99
+
100
+ async def _collection_search_base ( # noqa: C901
101
+ self ,
102
+ search_request : PgstacSearch ,
103
+ request : Request ,
104
+ ) -> Collections :
105
+ """Cross catalog search (POST).
106
+
107
+ Called with `POST /search`.
108
+
109
+ Args:
110
+ search_request: search request parameters.
111
+
112
+ Returns:
113
+ All collections which match the search criteria.
114
+ """
115
+
44
116
base_url = get_base_url (request )
45
117
46
- async with request .app .state .get_connection (request , "r" ) as conn :
47
- collections = await conn .fetchval (
48
- """
49
- SELECT * FROM all_collections();
50
- """
51
- )
118
+ settings : Settings = request .app .state .settings
119
+
120
+ if search_request .datetime :
121
+ search_request .datetime = format_datetime_range (search_request .datetime )
122
+
123
+ search_request .conf = search_request .conf or {}
124
+ search_request .conf ["nohydrate" ] = settings .use_api_hydrate
125
+
126
+ search_request_json = search_request .model_dump_json (
127
+ exclude_none = True , by_alias = True
128
+ )
129
+
130
+ try :
131
+ async with request .app .state .get_connection (request , "r" ) as conn :
132
+ q , p = render (
133
+ """
134
+ SELECT * FROM collection_search(:req::text::jsonb);
135
+ """ ,
136
+ req = search_request_json ,
137
+ )
138
+ collections_result : Collections = await conn .fetchval (q , * p )
139
+ except InvalidDatetimeFormatError as e :
140
+ raise InvalidQueryParameter (
141
+ f"Datetime parameter { search_request .datetime } is invalid."
142
+ ) from e
143
+
144
+ # next: Optional[str] = collections_result["links"].pop("next")
145
+ # prev: Optional[str] = collections_result["links"].pop("prev")
146
+
52
147
linked_collections : List [Collection ] = []
148
+ collections = collections_result ["collections" ]
53
149
if collections is not None and len (collections ) > 0 :
54
150
for c in collections :
55
151
coll = Collection (** c )
@@ -71,6 +167,12 @@ async def all_collections(self, request: Request, **kwargs) -> Collections:
71
167
72
168
linked_collections .append (coll )
73
169
170
+ # paging_links = await PagingLinks(
171
+ # request=request,
172
+ # next=next,
173
+ # prev=prev,
174
+ # ).get_links()
175
+
74
176
links = [
75
177
{
76
178
"rel" : Relations .root .value ,
@@ -88,8 +190,10 @@ async def all_collections(self, request: Request, **kwargs) -> Collections:
88
190
"href" : urljoin (base_url , "collections" ),
89
191
},
90
192
]
91
- collection_list = Collections (collections = linked_collections or [], links = links )
92
- return collection_list
193
+ return Collections (
194
+ collections = linked_collections or [],
195
+ links = links , # + paging_links
196
+ )
93
197
94
198
async def get_collection (
95
199
self , collection_id : str , request : Request , ** kwargs
@@ -383,7 +487,7 @@ async def post_search(
383
487
384
488
return ItemCollection (** item_collection )
385
489
386
- async def get_search ( # noqa: C901
490
+ async def get_search (
387
491
self ,
388
492
request : Request ,
389
493
collections : Optional [List [str ]] = None ,
@@ -418,49 +522,15 @@ async def get_search( # noqa: C901
418
522
"query" : orjson .loads (unquote_plus (query )) if query else query ,
419
523
}
420
524
421
- if filter :
422
- if filter_lang == "cql2-text" :
423
- ast = parse_cql2_text (filter )
424
- base_args ["filter" ] = orjson .loads (to_cql2 (ast ))
425
- base_args ["filter-lang" ] = "cql2-json"
426
-
427
- if datetime :
428
- base_args ["datetime" ] = format_datetime_range (datetime )
429
-
430
- if intersects :
431
- base_args ["intersects" ] = orjson .loads (unquote_plus (intersects ))
432
-
433
- if sortby :
434
- # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
435
- sort_param = []
436
- for sort in sortby :
437
- sortparts = re .match (r"^([+-]?)(.*)$" , sort )
438
- if sortparts :
439
- sort_param .append (
440
- {
441
- "field" : sortparts .group (2 ).strip (),
442
- "direction" : "desc" if sortparts .group (1 ) == "-" else "asc" ,
443
- }
444
- )
445
- base_args ["sortby" ] = sort_param
446
-
447
- if fields :
448
- includes = set ()
449
- excludes = set ()
450
- for field in fields :
451
- if field [0 ] == "-" :
452
- excludes .add (field [1 :])
453
- elif field [0 ] == "+" :
454
- includes .add (field [1 :])
455
- else :
456
- includes .add (field )
457
- base_args ["fields" ] = {"include" : includes , "exclude" : excludes }
458
-
459
- # Remove None values from dict
460
- clean = {}
461
- for k , v in base_args .items ():
462
- if v is not None and v != []:
463
- clean [k ] = v
525
+ clean = clean_search_args (
526
+ base_args = base_args ,
527
+ intersects = intersects ,
528
+ datetime = datetime ,
529
+ fields = fields ,
530
+ sortby = sortby ,
531
+ filter = filter ,
532
+ filter_lang = filter_lang ,
533
+ )
464
534
465
535
# Do the request
466
536
try :
@@ -471,3 +541,60 @@ async def get_search( # noqa: C901
471
541
) from e
472
542
473
543
return await self .post_search (search_request , request = request )
544
+
545
+
546
+ def clean_search_args ( # noqa: C901
547
+ base_args : dict [str , Any ],
548
+ intersects : Optional [str ] = None ,
549
+ datetime : Optional [DateTimeType ] = None ,
550
+ fields : Optional [List [str ]] = None ,
551
+ sortby : Optional [str ] = None ,
552
+ filter : Optional [str ] = None ,
553
+ filter_lang : Optional [str ] = None ,
554
+ ) -> dict [str , Any ]:
555
+ """Clean up search arguments to match format expected by pgstac"""
556
+ if filter :
557
+ if filter_lang == "cql2-text" :
558
+ ast = parse_cql2_text (filter )
559
+ base_args ["filter" ] = orjson .loads (to_cql2 (ast ))
560
+ base_args ["filter-lang" ] = "cql2-json"
561
+
562
+ if datetime :
563
+ base_args ["datetime" ] = format_datetime_range (datetime )
564
+
565
+ if intersects :
566
+ base_args ["intersects" ] = orjson .loads (unquote_plus (intersects ))
567
+
568
+ if sortby :
569
+ # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
570
+ sort_param = []
571
+ for sort in sortby :
572
+ sortparts = re .match (r"^([+-]?)(.*)$" , sort )
573
+ if sortparts :
574
+ sort_param .append (
575
+ {
576
+ "field" : sortparts .group (2 ).strip (),
577
+ "direction" : "desc" if sortparts .group (1 ) == "-" else "asc" ,
578
+ }
579
+ )
580
+ base_args ["sortby" ] = sort_param
581
+
582
+ if fields :
583
+ includes = set ()
584
+ excludes = set ()
585
+ for field in fields :
586
+ if field [0 ] == "-" :
587
+ excludes .add (field [1 :])
588
+ elif field [0 ] == "+" :
589
+ includes .add (field [1 :])
590
+ else :
591
+ includes .add (field )
592
+ base_args ["fields" ] = {"include" : includes , "exclude" : excludes }
593
+
594
+ # Remove None values from dict
595
+ clean = {}
596
+ for k , v in base_args .items ():
597
+ if v is not None and v != []:
598
+ clean [k ] = v
599
+
600
+ return clean
0 commit comments