39
39
class CoreCrudClient (AsyncBaseCoreClient ):
40
40
"""Client for core endpoints defined by stac."""
41
41
42
- async def all_collections (self , request : Request , ** kwargs ) -> Collections :
43
- """Read all collections from the database."""
42
+ async def all_collections ( # noqa: C901
43
+ self ,
44
+ request : Request ,
45
+ bbox : Optional [BBox ] = None ,
46
+ datetime : Optional [DateTimeType ] = None ,
47
+ limit : Optional [int ] = None ,
48
+ # Extensions
49
+ query : Optional [str ] = None ,
50
+ token : Optional [str ] = None ,
51
+ fields : Optional [List [str ]] = None ,
52
+ sortby : Optional [str ] = None ,
53
+ filter : Optional [str ] = None ,
54
+ filter_lang : Optional [str ] = None ,
55
+ ** kwargs ,
56
+ ) -> Collections :
57
+ """Cross catalog search (GET).
58
+
59
+ Called with `GET /collections`.
60
+
61
+ Returns:
62
+ Collections which match the search criteria, returns all
63
+ collections by default.
64
+ """
65
+ query_params = str (request .query_params )
66
+
67
+ # Kludgy fix because using factory does not allow alias for filter-lang
68
+ if filter_lang is None :
69
+ match = re .search (r"filter-lang=([a-z0-9-]+)" , query_params , re .IGNORECASE )
70
+ if match :
71
+ filter_lang = match .group (1 )
72
+
73
+ # Parse request parameters
74
+ base_args = {
75
+ "bbox" : bbox ,
76
+ "limit" : limit ,
77
+ "token" : token ,
78
+ "query" : orjson .loads (unquote_plus (query )) if query else query ,
79
+ }
80
+
81
+ clean = clean_search_args (
82
+ base_args = base_args ,
83
+ datetime = datetime ,
84
+ fields = fields ,
85
+ sortby = sortby ,
86
+ filter = filter ,
87
+ filter_lang = filter_lang ,
88
+ )
89
+
90
+ # Do the request
91
+ try :
92
+ search_request = self .post_request_model (** clean )
93
+ except ValidationError as e :
94
+ raise HTTPException (
95
+ status_code = 400 , detail = f"Invalid parameters provided { e } "
96
+ ) from e
97
+
98
+ return await self ._collection_search_base (search_request , request = request )
99
+
100
+ async def _collection_search_base ( # noqa: C901
101
+ self ,
102
+ search_request : PgstacSearch ,
103
+ request : Request ,
104
+ ) -> Collections :
105
+ """Cross catalog search (POST).
106
+
107
+ Called with `POST /search`.
108
+
109
+ Args:
110
+ search_request: search request parameters.
111
+
112
+ Returns:
113
+ All collections which match the search criteria.
114
+ """
115
+
44
116
base_url = get_base_url (request )
45
117
46
- async with request .app .state .get_connection (request , "r" ) as conn :
47
- collections = await conn .fetchval (
48
- """
49
- SELECT * FROM all_collections();
50
- """
51
- )
118
+ settings : Settings = request .app .state .settings
119
+
120
+ if search_request .datetime :
121
+ search_request .datetime = format_datetime_range (search_request .datetime )
122
+
123
+ search_request .conf = search_request .conf or {}
124
+ search_request .conf ["nohydrate" ] = settings .use_api_hydrate
125
+
126
+ search_request_json = search_request .model_dump_json (
127
+ exclude_none = True , by_alias = True
128
+ )
129
+
130
+ try :
131
+ async with request .app .state .get_connection (request , "r" ) as conn :
132
+ q , p = render (
133
+ """
134
+ SELECT * FROM collection_search(:req::text::jsonb);
135
+ """ ,
136
+ req = search_request_json ,
137
+ )
138
+ collections_result : Collections = await conn .fetchval (q , * p )
139
+ except InvalidDatetimeFormatError as e :
140
+ raise InvalidQueryParameter (
141
+ f"Datetime parameter { search_request .datetime } is invalid."
142
+ ) from e
143
+
144
+ # next: Optional[str] = collections_result["links"].pop("next")
145
+ # prev: Optional[str] = collections_result["links"].pop("prev")
146
+
52
147
linked_collections : List [Collection ] = []
148
+ collections = collections_result ["collections" ]
53
149
if collections is not None and len (collections ) > 0 :
54
150
for c in collections :
55
151
coll = Collection (** c )
@@ -59,6 +155,12 @@ async def all_collections(self, request: Request, **kwargs) -> Collections:
59
155
60
156
linked_collections .append (coll )
61
157
158
+ # paging_links = await PagingLinks(
159
+ # request=request,
160
+ # next=next,
161
+ # prev=prev,
162
+ # ).get_links()
163
+
62
164
links = [
63
165
{
64
166
"rel" : Relations .root .value ,
@@ -76,8 +178,10 @@ async def all_collections(self, request: Request, **kwargs) -> Collections:
76
178
"href" : urljoin (base_url , "collections" ),
77
179
},
78
180
]
79
- collection_list = Collections (collections = linked_collections or [], links = links )
80
- return collection_list
181
+ return Collections (
182
+ collections = linked_collections or [],
183
+ links = links , # + paging_links
184
+ )
81
185
82
186
async def get_collection (
83
187
self , collection_id : str , request : Request , ** kwargs
@@ -352,7 +456,7 @@ async def post_search(
352
456
353
457
return ItemCollection (** item_collection )
354
458
355
- async def get_search ( # noqa: C901
459
+ async def get_search (
356
460
self ,
357
461
request : Request ,
358
462
collections : Optional [List [str ]] = None ,
@@ -395,49 +499,15 @@ async def get_search( # noqa: C901
395
499
"query" : orjson .loads (unquote_plus (query )) if query else query ,
396
500
}
397
501
398
- if filter :
399
- if filter_lang == "cql2-text" :
400
- ast = parse_cql2_text (filter )
401
- base_args ["filter" ] = orjson .loads (to_cql2 (ast ))
402
- base_args ["filter-lang" ] = "cql2-json"
403
-
404
- if datetime :
405
- base_args ["datetime" ] = format_datetime_range (datetime )
406
-
407
- if intersects :
408
- base_args ["intersects" ] = orjson .loads (unquote_plus (intersects ))
409
-
410
- if sortby :
411
- # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
412
- sort_param = []
413
- for sort in sortby :
414
- sortparts = re .match (r"^([+-]?)(.*)$" , sort )
415
- if sortparts :
416
- sort_param .append (
417
- {
418
- "field" : sortparts .group (2 ).strip (),
419
- "direction" : "desc" if sortparts .group (1 ) == "-" else "asc" ,
420
- }
421
- )
422
- base_args ["sortby" ] = sort_param
423
-
424
- if fields :
425
- includes = set ()
426
- excludes = set ()
427
- for field in fields :
428
- if field [0 ] == "-" :
429
- excludes .add (field [1 :])
430
- elif field [0 ] == "+" :
431
- includes .add (field [1 :])
432
- else :
433
- includes .add (field )
434
- base_args ["fields" ] = {"include" : includes , "exclude" : excludes }
435
-
436
- # Remove None values from dict
437
- clean = {}
438
- for k , v in base_args .items ():
439
- if v is not None and v != []:
440
- clean [k ] = v
502
+ clean = clean_search_args (
503
+ base_args = base_args ,
504
+ intersects = intersects ,
505
+ datetime = datetime ,
506
+ fields = fields ,
507
+ sortby = sortby ,
508
+ filter = filter ,
509
+ filter_lang = filter_lang ,
510
+ )
441
511
442
512
# Do the request
443
513
try :
@@ -448,3 +518,60 @@ async def get_search( # noqa: C901
448
518
) from e
449
519
450
520
return await self .post_search (search_request , request = request )
521
+
522
+
523
+ def clean_search_args ( # noqa: C901
524
+ base_args : dict [str , Any ],
525
+ intersects : Optional [str ] = None ,
526
+ datetime : Optional [DateTimeType ] = None ,
527
+ fields : Optional [List [str ]] = None ,
528
+ sortby : Optional [str ] = None ,
529
+ filter : Optional [str ] = None ,
530
+ filter_lang : Optional [str ] = None ,
531
+ ) -> dict [str , Any ]:
532
+ """Clean up search arguments to match format expected by pgstac"""
533
+ if filter :
534
+ if filter_lang == "cql2-text" :
535
+ ast = parse_cql2_text (filter )
536
+ base_args ["filter" ] = orjson .loads (to_cql2 (ast ))
537
+ base_args ["filter-lang" ] = "cql2-json"
538
+
539
+ if datetime :
540
+ base_args ["datetime" ] = format_datetime_range (datetime )
541
+
542
+ if intersects :
543
+ base_args ["intersects" ] = orjson .loads (unquote_plus (intersects ))
544
+
545
+ if sortby :
546
+ # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
547
+ sort_param = []
548
+ for sort in sortby :
549
+ sortparts = re .match (r"^([+-]?)(.*)$" , sort )
550
+ if sortparts :
551
+ sort_param .append (
552
+ {
553
+ "field" : sortparts .group (2 ).strip (),
554
+ "direction" : "desc" if sortparts .group (1 ) == "-" else "asc" ,
555
+ }
556
+ )
557
+ base_args ["sortby" ] = sort_param
558
+
559
+ if fields :
560
+ includes = set ()
561
+ excludes = set ()
562
+ for field in fields :
563
+ if field [0 ] == "-" :
564
+ excludes .add (field [1 :])
565
+ elif field [0 ] == "+" :
566
+ includes .add (field [1 :])
567
+ else :
568
+ includes .add (field )
569
+ base_args ["fields" ] = {"include" : includes , "exclude" : excludes }
570
+
571
+ # Remove None values from dict
572
+ clean = {}
573
+ for k , v in base_args .items ():
574
+ if v is not None and v != []:
575
+ clean [k ] = v
576
+
577
+ return clean
0 commit comments