Skip to content

Commit 7495544

Browse files
committed
recycle collections_get_request_model in client
1 parent bfd94f4 commit 7495544

File tree

4 files changed

+19
-21
lines changed

4 files changed

+19
-21
lines changed

stac_fastapi/pgstac/app.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,12 @@
8383
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
8484
get_request_model = create_get_request_model(extensions)
8585

86-
# will only use parameters defined in collections_get_request_model
87-
collection_search_model = create_post_request_model(extensions, base_model=PgstacSearch)
88-
8986
api = StacApi(
9087
settings=settings,
9188
extensions=extensions + [collection_search_extension],
9289
client=CoreCrudClient(
9390
post_request_model=post_request_model, # type: ignore
94-
collection_request_model=collection_search_model, # type: ignore
91+
collections_get_request_model=collections_get_request_model, # type: ignore
9592
),
9693
response_class=ORJSONResponse,
9794
items_get_request_model=items_get_request_model,

stac_fastapi/pgstac/core.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Item crud client."""
22

3+
import json
34
import re
45
from typing import Any, Dict, List, Optional, Set, Union
56
from urllib.parse import unquote_plus, urljoin
@@ -13,7 +14,7 @@
1314
from pygeofilter.backends.cql2_json import to_cql2
1415
from pygeofilter.parsers.cql2_text import parse as parse_cql2_text
1516
from pypgstac.hydration import hydrate
16-
from stac_fastapi.api.models import JSONResponse
17+
from stac_fastapi.api.models import APIRequest, EmptyRequest, JSONResponse
1718
from stac_fastapi.types.core import AsyncBaseCoreClient, Relations
1819
from stac_fastapi.types.errors import InvalidQueryParameter, NotFoundError
1920
from stac_fastapi.types.requests import get_base_url
@@ -38,7 +39,7 @@
3839
class CoreCrudClient(AsyncBaseCoreClient):
3940
"""Client for core endpoints defined by stac."""
4041

41-
collection_request_model = attr.ib(default=PgstacSearch)
42+
collections_get_request_model: APIRequest = attr.ib(default=EmptyRequest)
4243

4344
async def all_collections( # noqa: C901
4445
self,
@@ -83,7 +84,8 @@ async def all_collections( # noqa: C901
8384

8485
# Do the request
8586
try:
86-
search_request = self.collection_request_model(**clean)
87+
search_request = self.collections_get_request_model(**clean)
88+
print(search_request)
8789
except ValidationError as e:
8890
raise HTTPException(
8991
status_code=400, detail=f"Invalid parameters provided {e}"
@@ -93,7 +95,7 @@ async def all_collections( # noqa: C901
9395

9496
async def _collection_search_base( # noqa: C901
9597
self,
96-
search_request: PgstacSearch,
98+
search_request: APIRequest,
9799
request: Request,
98100
) -> Collections:
99101
"""Cross catalog search (GET).
@@ -107,8 +109,12 @@ async def _collection_search_base( # noqa: C901
107109
All collections which match the search criteria.
108110
"""
109111
base_url = get_base_url(request)
110-
search_request_json = search_request.model_dump_json(
111-
exclude_none=True, by_alias=True
112+
search_request_json = json.dumps(
113+
{
114+
key: value
115+
for key, value in search_request.__dict__.items()
116+
if value is not None
117+
}
112118
)
113119

114120
try:
@@ -533,7 +539,7 @@ def clean_search_args( # noqa: C901
533539
filter_lang = "cql2-json"
534540

535541
base_args["filter"] = orjson.loads(filter_query)
536-
base_args["filter-lang"] = filter_lang
542+
base_args["filter_lang"] = filter_lang
537543

538544
if datetime:
539545
base_args["datetime"] = format_datetime_range(datetime)

tests/api/test_api.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -730,23 +730,21 @@ async def get_collection(
730730
]
731731
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
732732
get_request_model = create_get_request_model(extensions)
733-
collection_search_model = create_post_request_model(
734-
extensions, base_model=PgstacSearch
735-
)
733+
736734
collection_search_extension = CollectionSearchExtension.from_extensions(
737735
extensions=extensions
738736
)
739-
collections_get_request_model = collection_search_extension.GET
737+
740738
api = StacApi(
741739
client=Client(
742740
post_request_model=post_request_model,
743-
collection_request_model=collection_search_model,
741+
collections_get_request_model=collection_search_extension.GET,
744742
),
745743
settings=settings,
746744
extensions=extensions,
747745
search_post_request_model=post_request_model,
748746
search_get_request_model=get_request_model,
749-
collections_get_request_model=collections_get_request_model,
747+
collections_get_request_model=collection_search_extension.GET,
750748
)
751749
app = api.app
752750
await connect_to_db(app)

tests/conftest.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,16 +151,13 @@ def api_client(request, database):
151151
)
152152

153153
collections_get_request_model = collection_search_extension.GET
154-
collection_search_model = create_post_request_model(
155-
extensions, base_model=PgstacSearch
156-
)
157154

158155
api = StacApi(
159156
settings=api_settings,
160157
extensions=extensions + [collection_search_extension],
161158
client=CoreCrudClient(
162159
post_request_model=search_post_request_model,
163-
collection_request_model=collection_search_model,
160+
collections_get_request_model=collections_get_request_model,
164161
),
165162
items_get_request_model=items_get_request_model,
166163
search_get_request_model=search_get_request_model,

0 commit comments

Comments
 (0)