Skip to content

Commit 0d36ee4

Browse files
committed
Implement abstract fetch methods in clients
1 parent 5b776c8 commit 0d36ee4

File tree

2 files changed

+21
-65
lines changed

2 files changed

+21
-65
lines changed

stac_fastapi/pgstac/core.py

Lines changed: 14 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import re
33
from datetime import datetime
44
from typing import Any, Dict, List, Optional, Union
5-
from urllib.parse import urljoin
65

76
import attr
87
import orjson
@@ -13,8 +12,6 @@
1312
from pygeofilter.backends.cql2_json import to_cql2
1413
from pygeofilter.parsers.cql2_text import parse as parse_cql2_text
1514
from pypgstac.hydration import hydrate
16-
from stac_pydantic.links import Relations
17-
from stac_pydantic.shared import MimeTypes
1815
from starlette.requests import Request
1916

2017
from stac_fastapi.pgstac.config import Settings
@@ -23,7 +20,6 @@
2320
from stac_fastapi.pgstac.utils import filter_fields
2421
from stac_fastapi.types.core import AsyncBaseCoreClient
2522
from stac_fastapi.types.errors import InvalidQueryParameter, NotFoundError
26-
from stac_fastapi.types.requests import get_base_url
2723
from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection
2824

2925
NumType = Union[float, int]
@@ -33,49 +29,20 @@
3329
class CoreCrudClient(AsyncBaseCoreClient):
3430
"""Client for core endpoints defined by stac."""
3531

36-
async def all_collections(self, **kwargs) -> Collections:
32+
async def fetch_all_collections(self, request: Request) -> Collections:
3733
"""Read all collections from the database."""
38-
request: Request = kwargs["request"]
39-
base_url = get_base_url(request)
4034
pool = request.app.state.readpool
4135

4236
async with pool.acquire() as conn:
43-
collections = await conn.fetchval(
37+
return await conn.fetchval(
4438
"""
4539
SELECT * FROM all_collections();
4640
"""
4741
)
48-
linked_collections: List[Collection] = []
49-
if collections is not None and len(collections) > 0:
50-
for c in collections:
51-
coll = Collection(**c)
52-
coll["links"] = await CollectionLinks(
53-
collection_id=coll["id"], request=request
54-
).get_links(extra_links=coll.get("links"))
55-
56-
linked_collections.append(coll)
57-
58-
links = [
59-
{
60-
"rel": Relations.root.value,
61-
"type": MimeTypes.json,
62-
"href": base_url,
63-
},
64-
{
65-
"rel": Relations.parent.value,
66-
"type": MimeTypes.json,
67-
"href": base_url,
68-
},
69-
{
70-
"rel": Relations.self.value,
71-
"type": MimeTypes.json,
72-
"href": urljoin(base_url, "collections"),
73-
},
74-
]
75-
collection_list = Collections(collections=linked_collections or [], links=links)
76-
return collection_list
77-
78-
async def get_collection(self, collection_id: str, **kwargs) -> Collection:
42+
43+
async def fetch_collection(
44+
self, collection_id: str, request: Request
45+
) -> Collection:
7946
"""Get collection by id.
8047
8148
Called with `GET /collections/{collection_id}`.
@@ -86,9 +53,6 @@ async def get_collection(self, collection_id: str, **kwargs) -> Collection:
8653
Returns:
8754
Collection.
8855
"""
89-
collection: Optional[Dict[str, Any]]
90-
91-
request: Request = kwargs["request"]
9256
pool = request.app.state.readpool
9357
async with pool.acquire() as conn:
9458
q, p = render(
@@ -97,15 +61,7 @@ async def get_collection(self, collection_id: str, **kwargs) -> Collection:
9761
""",
9862
id=collection_id,
9963
)
100-
collection = await conn.fetchval(q, *p)
101-
if collection is None:
102-
raise NotFoundError(f"Collection {collection_id} does not exist.")
103-
104-
collection["links"] = await CollectionLinks(
105-
collection_id=collection_id, request=request
106-
).get_links(extra_links=collection.get("links"))
107-
108-
return Collection(**collection)
64+
return await conn.fetchval(q, *p)
10965

11066
async def _get_base_item(
11167
self, collection_id: str, request: Request
@@ -245,7 +201,7 @@ async def _get_base_item(collection_id: str) -> Dict[str, Any]:
245201
).get_links()
246202
return collection
247203

248-
async def item_collection(
204+
async def handle_collection_items(
249205
self,
250206
collection_id: str,
251207
limit: Optional[int] = None,
@@ -265,7 +221,7 @@ async def item_collection(
265221
An ItemCollection.
266222
"""
267223
# If collection does not exist, NotFoundError wil be raised
268-
await self.get_collection(collection_id, **kwargs)
224+
await self.handle_get_collection(collection_id, **kwargs)
269225

270226
req = self.post_request_model(
271227
collections=[collection_id], limit=limit, token=token
@@ -277,7 +233,7 @@ async def item_collection(
277233
item_collection["links"] = links
278234
return item_collection
279235

280-
async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item:
236+
async def handle_get_item(self, item_id: str, collection_id: str, **kwargs) -> Item:
281237
"""Get item by id.
282238
283239
Called with `GET /collections/{collection_id}/items/{item_id}`.
@@ -290,7 +246,7 @@ async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item:
290246
Item.
291247
"""
292248
# If collection does not exist, NotFoundError wil be raised
293-
await self.get_collection(collection_id, **kwargs)
249+
await self.handle_get_collection(collection_id, **kwargs)
294250

295251
req = self.post_request_model(
296252
ids=[item_id], collections=[collection_id], limit=1
@@ -303,7 +259,7 @@ async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item:
303259

304260
return Item(**item_collection["features"][0])
305261

306-
async def post_search(
262+
async def handle_post_search(
307263
self, search_request: PgstacSearch, **kwargs
308264
) -> ItemCollection:
309265
"""Cross catalog search (POST).
@@ -319,7 +275,7 @@ async def post_search(
319275
item_collection = await self._search_base(search_request, **kwargs)
320276
return ItemCollection(**item_collection)
321277

322-
async def get_search(
278+
async def handle_get_search(
323279
self,
324280
collections: Optional[List[str]] = None,
325281
ids: Optional[List[str]] = None,
@@ -408,4 +364,4 @@ async def get_search(
408364
raise HTTPException(
409365
status_code=400, detail=f"Invalid parameters provided {e}"
410366
)
411-
return await self.post_search(search_request, request=kwargs["request"])
367+
return await self.handle_post_search(search_request, request=kwargs["request"])

stac_fastapi/pgstac/transactions.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
class TransactionsClient(AsyncBaseTransactionsClient):
2525
"""Transactions extension specific CRUD operations."""
2626

27-
async def create_item(
27+
async def handle_create_item(
2828
self, collection_id: str, item: stac_types.Item, **kwargs
2929
) -> Optional[Union[stac_types.Item, Response]]:
3030
"""Create item."""
@@ -45,7 +45,7 @@ async def create_item(
4545
).get_links(extra_links=item.get("links"))
4646
return stac_types.Item(**item)
4747

48-
async def update_item(
48+
async def handle_update_item(
4949
self, collection_id: str, item_id: str, item: stac_types.Item, **kwargs
5050
) -> Optional[Union[stac_types.Item, Response]]:
5151
"""Update item."""
@@ -72,7 +72,7 @@ async def update_item(
7272
).get_links(extra_links=item.get("links"))
7373
return stac_types.Item(**item)
7474

75-
async def create_collection(
75+
async def handle_create_collection(
7676
self, collection: stac_types.Collection, **kwargs
7777
) -> Optional[Union[stac_types.Collection, Response]]:
7878
"""Create collection."""
@@ -85,7 +85,7 @@ async def create_collection(
8585

8686
return stac_types.Collection(**collection)
8787

88-
async def update_collection(
88+
async def handle_update_collection(
8989
self, collection: stac_types.Collection, **kwargs
9090
) -> Optional[Union[stac_types.Collection, Response]]:
9191
"""Update collection."""
@@ -97,7 +97,7 @@ async def update_collection(
9797
).get_links(extra_links=collection.get("links"))
9898
return stac_types.Collection(**collection)
9999

100-
async def delete_item(
100+
async def handle_delete_item(
101101
self, item_id: str, **kwargs
102102
) -> Optional[Union[stac_types.Item, Response]]:
103103
"""Delete item."""
@@ -106,7 +106,7 @@ async def delete_item(
106106
await dbfunc(pool, "delete_item", item_id)
107107
return JSONResponse({"deleted item": item_id})
108108

109-
async def delete_collection(
109+
async def handle_delete_collection(
110110
self, collection_id: str, **kwargs
111111
) -> Optional[Union[stac_types.Collection, Response]]:
112112
"""Delete collection."""
@@ -120,7 +120,7 @@ async def delete_collection(
120120
class BulkTransactionsClient(AsyncBaseBulkTransactionsClient):
121121
"""Postgres bulk transactions."""
122122

123-
async def bulk_item_insert(self, items: Items, **kwargs) -> str:
123+
async def handle_bulk_item_insert(self, items: Items, **kwargs) -> str:
124124
"""Bulk item insertion using pgstac."""
125125
request = kwargs["request"]
126126
pool = request.app.state.writepool

0 commit comments

Comments
 (0)