Skip to content

Commit 226bdc0

Browse files
authored
Added bulk item inserts for pgstac implementation (#411)
* Added bulk item inserts for pgstac implementation * Updated changelog
1 parent 162a1a2 commit 226bdc0

File tree

6 files changed

+90
-6
lines changed

6 files changed

+90
-6
lines changed

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
* Add hook to allow adding dependencies to routes. ([#295](https://github.com/stac-utils/stac-fastapi/pull/295))
88
* Ability to POST an ItemCollection to the collections/{collectionId}/items route. ([#367](https://github.com/stac-utils/stac-fastapi/pull/367))
99
* Add STAC API - Collections conformance class. ([383](https://github.com/stac-utils/stac-fastapi/pull/383))
10+
* Bulk item inserts for pgstac implementation. ([411](https://github.com/stac-utils/stac-fastapi/pull/411))
1011

1112
### Changed
1213

stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
"""bulk transactions extension."""
22
import abc
3-
from typing import Any, Dict, List, Optional
3+
from typing import Any, Callable, Dict, List, Optional, Type, Union
44

55
import attr
66
from fastapi import APIRouter, FastAPI
77
from pydantic import BaseModel
88

99
from stac_fastapi.api.models import create_request_model
10-
from stac_fastapi.api.routes import create_sync_endpoint
10+
from stac_fastapi.api.routes import create_async_endpoint, create_sync_endpoint
1111
from stac_fastapi.types.extension import ApiExtension
12+
from stac_fastapi.types.search import APIRequest
1213

1314

1415
class Items(BaseModel):
@@ -51,6 +52,24 @@ def bulk_item_insert(
5152
raise NotImplementedError
5253

5354

55+
@attr.s # type: ignore
56+
class AsyncBaseBulkTransactionsClient(abc.ABC):
57+
"""BulkTransactionsClient."""
58+
59+
@abc.abstractmethod
60+
async def bulk_item_insert(self, items: Items, **kwargs) -> str:
61+
"""Bulk creation of items.
62+
63+
Args:
64+
items: list of items.
65+
66+
Returns:
67+
Message indicating the status of the insert.
68+
69+
"""
70+
raise NotImplementedError
71+
72+
5473
@attr.s
5574
class BulkTransactionExtension(ApiExtension):
5675
"""Bulk Transaction Extension.
@@ -68,10 +87,24 @@ class BulkTransactionExtension(ApiExtension):
6887
6988
"""
7089

71-
client: BaseBulkTransactionsClient = attr.ib()
90+
client: Union[
91+
AsyncBaseBulkTransactionsClient, BaseBulkTransactionsClient
92+
] = attr.ib()
7293
conformance_classes: List[str] = attr.ib(default=list())
7394
schema_href: Optional[str] = attr.ib(default=None)
7495

96+
def _create_endpoint(
97+
self,
98+
func: Callable,
99+
request_type: Union[Type[APIRequest], Type[BaseModel], Dict],
100+
) -> Callable:
101+
"""Create a FastAPI endpoint."""
102+
if isinstance(self.client, AsyncBaseBulkTransactionsClient):
103+
return create_async_endpoint(func, request_type)
104+
elif isinstance(self.client, BaseBulkTransactionsClient):
105+
return create_sync_endpoint(func, request_type)
106+
raise NotImplementedError
107+
75108
def register(self, app: FastAPI) -> None:
76109
"""Register the extension with a FastAPI application.
77110
@@ -91,7 +124,7 @@ def register(self, app: FastAPI) -> None:
91124
response_model_exclude_unset=True,
92125
response_model_exclude_none=True,
93126
methods=["POST"],
94-
endpoint=create_sync_endpoint(
127+
endpoint=self._create_endpoint(
95128
self.client.bulk_item_insert, items_request_model
96129
),
97130
)

stac_fastapi/pgstac/stac_fastapi/pgstac/app.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@
1010
TokenPaginationExtension,
1111
TransactionExtension,
1212
)
13+
from stac_fastapi.extensions.third_party import BulkTransactionExtension
1314
from stac_fastapi.pgstac.config import Settings
1415
from stac_fastapi.pgstac.core import CoreCrudClient
1516
from stac_fastapi.pgstac.db import close_db_connection, connect_to_db
1617
from stac_fastapi.pgstac.extensions import QueryExtension
17-
from stac_fastapi.pgstac.transactions import TransactionsClient
18+
from stac_fastapi.pgstac.transactions import BulkTransactionsClient, TransactionsClient
1819
from stac_fastapi.pgstac.types.search import PgstacSearch
1920

2021
settings = Settings()
@@ -29,6 +30,7 @@
2930
FieldsExtension(),
3031
TokenPaginationExtension(),
3132
ContextExtension(),
33+
BulkTransactionExtension(client=BulkTransactionsClient()),
3234
]
3335

3436
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)

stac_fastapi/pgstac/stac_fastapi/pgstac/transactions.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
import attr
77
from starlette.responses import JSONResponse, Response
88

9+
from stac_fastapi.extensions.third_party.bulk_transactions import (
10+
AsyncBaseBulkTransactionsClient,
11+
Items,
12+
)
913
from stac_fastapi.pgstac.db import dbfunc
1014
from stac_fastapi.types import stac as stac_types
1115
from stac_fastapi.types.core import AsyncBaseTransactionsClient
@@ -71,3 +75,18 @@ async def delete_collection(
7175
pool = request.app.state.writepool
7276
await dbfunc(pool, "delete_collection", collection_id)
7377
return JSONResponse({"deleted collection": collection_id})
78+
79+
80+
@attr.s
81+
class BulkTransactionsClient(AsyncBaseBulkTransactionsClient):
82+
"""Postgres bulk transactions."""
83+
84+
async def bulk_item_insert(self, items: Items, **kwargs) -> str:
85+
"""Bulk item insertion using pgstac."""
86+
request = kwargs["request"]
87+
pool = request.app.state.writepool
88+
items = list(items.items.values())
89+
await dbfunc(pool, "create_items", items)
90+
91+
return_msg = f"Successfully added {len(items)} items."
92+
return return_msg

stac_fastapi/pgstac/tests/clients/test_postgres.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import uuid
2+
from copy import deepcopy
23
from typing import Callable
34

45
from stac_pydantic import Collection, Item
@@ -117,6 +118,32 @@ async def test_get_collection_items(app_client, load_test_collection, load_test_
117118
assert len(fc["features"]) == 5
118119

119120

121+
async def test_create_bulk_items(
122+
app_client, load_test_data: Callable, load_test_collection
123+
):
124+
coll = load_test_collection
125+
item = load_test_data("test_item.json")
126+
127+
items = {}
128+
for _ in range(2):
129+
_item = deepcopy(item)
130+
_item["id"] = str(uuid.uuid4())
131+
items[_item["id"]] = _item
132+
133+
payload = {"items": items}
134+
135+
resp = await app_client.post(
136+
f"/collections/{coll.id}/bulk_items",
137+
json=payload,
138+
)
139+
assert resp.status_code == 200
140+
assert resp.text == '"Successfully added 2 items."'
141+
142+
for item_id in items.keys():
143+
resp = await app_client.get(f"/collections/{coll.id}/items/{item_id}")
144+
assert resp.status_code == 200
145+
146+
120147
# TODO since right now puts implement upsert
121148
# test_create_collection_already_exists
122149
# test create_item_already_exists

stac_fastapi/pgstac/tests/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@
2121
TokenPaginationExtension,
2222
TransactionExtension,
2323
)
24+
from stac_fastapi.extensions.third_party import BulkTransactionExtension
2425
from stac_fastapi.pgstac.config import Settings
2526
from stac_fastapi.pgstac.core import CoreCrudClient
2627
from stac_fastapi.pgstac.db import close_db_connection, connect_to_db
2728
from stac_fastapi.pgstac.extensions import QueryExtension
28-
from stac_fastapi.pgstac.transactions import TransactionsClient
29+
from stac_fastapi.pgstac.transactions import BulkTransactionsClient, TransactionsClient
2930
from stac_fastapi.pgstac.types.search import PgstacSearch
3031

3132
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
@@ -117,6 +118,7 @@ def api_client(request, pg):
117118
SortExtension(),
118119
FieldsExtension(),
119120
TokenPaginationExtension(),
121+
BulkTransactionExtension(client=BulkTransactionsClient()),
120122
]
121123
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
122124
api = StacApi(

0 commit comments

Comments
 (0)