Skip to content

Commit 810bbb5

Browse files
update fields extension and make sure the app can work without any extension (#123)
* update fields extension and make sure the app can work without any extension * Update stac_fastapi/pgstac/core.py Co-authored-by: Jonathan Healy <[email protected]> --------- Co-authored-by: Jonathan Healy <[email protected]>
1 parent 26f6d91 commit 810bbb5

File tree

6 files changed

+131
-23
lines changed

6 files changed

+131
-23
lines changed

CHANGES.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
## [Unreleased]
44

5+
- Update stac-fastapi libraries to `~=3.0.0a3`
6+
- make sure the application can work without any extension
7+
58
## [3.0.0a1] - 2024-05-22
69

710
- Update stac-fastapi libraries to `~=3.0.0a1`

setup.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
"orjson",
1111
"pydantic",
1212
"stac_pydantic==3.1.*",
13-
"stac-fastapi.api~=3.0.0a1",
14-
"stac-fastapi.extensions~=3.0.0a1",
15-
"stac-fastapi.types~=3.0.0a1",
13+
"stac-fastapi.api~=3.0.0a3",
14+
"stac-fastapi.extensions~=3.0.0a3",
15+
"stac-fastapi.types~=3.0.0a3",
1616
"asyncpg",
1717
"buildpg",
1818
"brotli_asgi",

stac_fastapi/pgstac/app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,13 @@
5050
extensions = list(extensions_map.values())
5151

5252
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
53-
53+
get_request_model = create_get_request_model(extensions)
5454
api = StacApi(
5555
settings=settings,
5656
extensions=extensions,
5757
client=CoreCrudClient(post_request_model=post_request_model), # type: ignore
5858
response_class=ORJSONResponse,
59-
search_get_request_model=create_get_request_model(extensions),
59+
search_get_request_model=get_request_model,
6060
search_post_request_model=post_request_model,
6161
)
6262
app = api.app

stac_fastapi/pgstac/core.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Item crud client."""
22

33
import re
4-
from typing import Any, Dict, List, Optional, Union
4+
from typing import Any, Dict, List, Optional, Set, Union
55
from urllib.parse import unquote_plus, urljoin
66

77
import attr
@@ -184,12 +184,9 @@ async def _search_base( # noqa: C901
184184
prev: Optional[str] = items.pop("prev", None)
185185
collection = ItemCollection(**items)
186186

187-
exclude = search_request.fields.exclude
188-
if exclude and len(exclude) == 0:
189-
exclude = None
190-
include = search_request.fields.include
191-
if include and len(include) == 0:
192-
include = None
187+
fields = getattr(search_request, "fields", None)
188+
include: Set[str] = fields.include if fields and fields.include else set()
189+
exclude: Set[str] = fields.exclude if fields and fields.exclude else set()
193190

194191
async def _add_item_links(
195192
feature: Item,
@@ -204,11 +201,7 @@ async def _add_item_links(
204201
collection_id = feature.get("collection") or collection_id
205202
item_id = feature.get("id") or item_id
206203

207-
if (
208-
search_request.fields.exclude is None
209-
or "links" not in search_request.fields.exclude
210-
and all([collection_id, item_id])
211-
):
204+
if not exclude or "links" not in exclude and all([collection_id, item_id]):
212205
feature["links"] = await ItemLinks(
213206
collection_id=collection_id, # type: ignore
214207
item_id=item_id, # type: ignore
@@ -252,6 +245,7 @@ async def _get_base_item(collection_id: str) -> Dict[str, Any]:
252245
next=next,
253246
prev=prev,
254247
).get_links()
248+
255249
return collection
256250

257251
async def item_collection(
@@ -295,14 +289,14 @@ async def item_collection(
295289
if v is not None and v != []:
296290
clean[k] = v
297291

298-
search_request = self.post_request_model(
299-
**clean,
300-
)
292+
search_request = self.post_request_model(**clean)
301293
item_collection = await self._search_base(search_request, request=request)
294+
302295
links = await ItemCollectionLinks(
303296
collection_id=collection_id, request=request
304297
).get_links(extra_links=item_collection["links"])
305298
item_collection["links"] = links
299+
306300
return item_collection
307301

308302
async def get_item(
@@ -355,15 +349,16 @@ async def get_search( # noqa: C901
355349
collections: Optional[List[str]] = None,
356350
ids: Optional[List[str]] = None,
357351
bbox: Optional[BBox] = None,
352+
intersects: Optional[str] = None,
358353
datetime: Optional[DateTimeType] = None,
359354
limit: Optional[int] = None,
355+
# Extensions
360356
query: Optional[str] = None,
361357
token: Optional[str] = None,
362358
fields: Optional[List[str]] = None,
363359
sortby: Optional[str] = None,
364360
filter: Optional[str] = None,
365361
filter_lang: Optional[str] = None,
366-
intersects: Optional[str] = None,
367362
**kwargs,
368363
) -> ItemCollection:
369364
"""Cross catalog search (GET).

stac_fastapi/pgstac/extensions/filter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from buildpg import render
66
from fastapi import Request
7-
from stac_fastapi.types.core import AsyncBaseFiltersClient
7+
from stac_fastapi.extensions.core.filter.client import AsyncBaseFiltersClient
88
from stac_fastapi.types.errors import NotFoundError
99

1010

tests/api/test_api.py

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from datetime import datetime, timedelta
23
from typing import Any, Callable, Coroutine, Dict, List, Optional, TypeVar
34
from urllib.parse import quote_plus
@@ -6,9 +7,11 @@
67
import pytest
78
from fastapi import Request
89
from httpx import ASGITransport, AsyncClient
10+
from pypgstac.db import PgstacDB
11+
from pypgstac.load import Loader
912
from pystac import Collection, Extent, Item, SpatialExtent, TemporalExtent
1013
from stac_fastapi.api.app import StacApi
11-
from stac_fastapi.api.models import create_post_request_model
14+
from stac_fastapi.api.models import create_get_request_model, create_post_request_model
1215
from stac_fastapi.extensions.core import FieldsExtension, TransactionExtension
1316
from stac_fastapi.types import stac as stac_types
1417

@@ -17,6 +20,9 @@
1720
from stac_fastapi.pgstac.transactions import TransactionsClient
1821
from stac_fastapi.pgstac.types.search import PgstacSearch
1922

23+
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
24+
25+
2026
STAC_CORE_ROUTES = [
2127
"GET /",
2228
"GET /collections",
@@ -669,11 +675,13 @@ async def get_collection(
669675
FieldsExtension(),
670676
]
671677
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
678+
get_request_model = create_get_request_model(extensions)
672679
api = StacApi(
673680
client=Client(post_request_model=post_request_model),
674681
settings=settings,
675682
extensions=extensions,
676683
search_post_request_model=post_request_model,
684+
search_get_request_model=get_request_model,
677685
)
678686
app = api.app
679687
await connect_to_db(app)
@@ -695,3 +703,105 @@ async def get_collection(
695703
assert response.status_code == 200
696704
finally:
697705
await close_db_connection(app)
706+
707+
708+
@pytest.mark.asyncio
709+
@pytest.mark.parametrize("validation", [True, False])
710+
@pytest.mark.parametrize("hydrate", [True, False])
711+
async def test_no_extension(
712+
hydrate, validation, load_test_data, database, pgstac
713+
) -> None:
714+
"""test PgSTAC with no extension."""
715+
connection = f"postgresql://{database.user}:{database.password}@{database.host}:{database.port}/{database.dbname}"
716+
with PgstacDB(dsn=connection) as db:
717+
loader = Loader(db=db)
718+
loader.load_collections(os.path.join(DATA_DIR, "test_collection.json"))
719+
loader.load_items(os.path.join(DATA_DIR, "test_item.json"))
720+
721+
settings = Settings(
722+
postgres_user=database.user,
723+
postgres_pass=database.password,
724+
postgres_host_reader=database.host,
725+
postgres_host_writer=database.host,
726+
postgres_port=database.port,
727+
postgres_dbname=database.dbname,
728+
testing=True,
729+
use_api_hydrate=hydrate,
730+
enable_response_models=validation,
731+
)
732+
extensions = []
733+
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
734+
api = StacApi(
735+
client=CoreCrudClient(post_request_model=post_request_model),
736+
settings=settings,
737+
extensions=extensions,
738+
search_post_request_model=post_request_model,
739+
)
740+
app = api.app
741+
await connect_to_db(app)
742+
try:
743+
async with AsyncClient(transport=ASGITransport(app=app)) as client:
744+
landing = await client.get("http://test/")
745+
assert landing.status_code == 200, landing.text
746+
747+
collection = await client.get("http://test/collections/test-collection")
748+
assert collection.status_code == 200, collection.text
749+
750+
collections = await client.get("http://test/collections")
751+
assert collections.status_code == 200, collections.text
752+
753+
item = await client.get(
754+
"http://test/collections/test-collection/items/test-item"
755+
)
756+
assert item.status_code == 200, item.text
757+
758+
item_collection = await client.get(
759+
"http://test/collections/test-collection/items",
760+
params={"limit": 10},
761+
)
762+
assert item_collection.status_code == 200, item_collection.text
763+
764+
get_search = await client.get(
765+
"http://test/search",
766+
params={
767+
"collections": ["test-collection"],
768+
},
769+
)
770+
assert get_search.status_code == 200, get_search.text
771+
772+
post_search = await client.post(
773+
"http://test/search",
774+
json={
775+
"collections": ["test-collection"],
776+
},
777+
)
778+
assert post_search.status_code == 200, post_search.text
779+
780+
get_search = await client.get(
781+
"http://test/search",
782+
params={
783+
"collections": ["test-collection"],
784+
"fields": "properties.datetime",
785+
},
786+
)
787+
# fields should be ignored
788+
assert get_search.status_code == 200, get_search.text
789+
props = get_search.json()["features"][0]["properties"]
790+
assert len(props) > 1
791+
792+
post_search = await client.post(
793+
"http://test/search",
794+
json={
795+
"collections": ["test-collection"],
796+
"fields": {
797+
"include": ["properties.datetime"],
798+
},
799+
},
800+
)
801+
# fields should be ignored
802+
assert post_search.status_code == 200, post_search.text
803+
props = get_search.json()["features"][0]["properties"]
804+
assert len(props) > 1
805+
806+
finally:
807+
await close_db_connection(app)

0 commit comments

Comments
 (0)