Skip to content

Commit aa4bcbd

Browse files
committed
Add collection_id path parameter and check against Item collection property
1 parent 0483406 commit aa4bcbd

File tree

4 files changed

+82
-8
lines changed

4 files changed

+82
-8
lines changed

stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Callable, List, Optional, Type, Union
33

44
import attr
5-
from fastapi import APIRouter, FastAPI
5+
from fastapi import APIRouter, Body, FastAPI
66
from pydantic import BaseModel
77
from stac_pydantic import Collection, Item
88
from starlette.responses import JSONResponse, Response
@@ -15,6 +15,13 @@
1515
from stac_fastapi.types.extension import ApiExtension
1616

1717

18+
@attr.s
19+
class PostOrPutItem(CollectionUri):
20+
"""Create or update Item."""
21+
22+
item: stac_types.Item = attr.ib(default=Body())
23+
24+
1825
@attr.s
1926
class TransactionExtension(ApiExtension):
2027
"""Transaction Extension.
@@ -77,7 +84,7 @@ def register_create_item(self):
7784
response_model_exclude_unset=True,
7885
response_model_exclude_none=True,
7986
methods=["POST"],
80-
endpoint=self._create_endpoint(self.client.create_item, stac_types.Item),
87+
endpoint=self._create_endpoint(self.client.create_item, PostOrPutItem),
8188
)
8289

8390
def register_update_item(self):
@@ -90,7 +97,7 @@ def register_update_item(self):
9097
response_model_exclude_unset=True,
9198
response_model_exclude_none=True,
9299
methods=["PUT"],
93-
endpoint=self._create_endpoint(self.client.update_item, stac_types.Item),
100+
endpoint=self._create_endpoint(self.client.update_item, PostOrPutItem),
94101
)
95102

96103
def register_delete_item(self):

stac_fastapi/pgstac/stac_fastapi/pgstac/transactions.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Optional, Union
55

66
import attr
7+
from fastapi import HTTPException
78
from starlette.responses import JSONResponse, Response
89

910
from stac_fastapi.extensions.third_party.bulk_transactions import (
@@ -23,18 +24,32 @@ class TransactionsClient(AsyncBaseTransactionsClient):
2324
"""Transactions extension specific CRUD operations."""
2425

2526
async def create_item(
26-
self, item: stac_types.Item, **kwargs
27+
self, collection_id: str, item: stac_types.Item, **kwargs
2728
) -> Optional[Union[stac_types.Item, Response]]:
2829
"""Create item."""
30+
item_collection_id = item.get("collection")
31+
if item_collection_id is not None and collection_id != item_collection_id:
32+
raise HTTPException(
33+
status_code=409,
34+
detail=f"Collection ID from path parameter ({collection_id}) does not match Collection ID from Item ({item_collection_id})",
35+
)
36+
item["collection"] = collection_id
2937
request = kwargs["request"]
3038
pool = request.app.state.writepool
3139
await dbfunc(pool, "create_item", item)
3240
return item
3341

3442
async def update_item(
35-
self, item: stac_types.Item, **kwargs
43+
self, collection_id: str, item: stac_types.Item, **kwargs
3644
) -> Optional[Union[stac_types.Item, Response]]:
3745
"""Update item."""
46+
item_collection_id = item.get("collection")
47+
if item_collection_id is not None and collection_id != item_collection_id:
48+
raise HTTPException(
49+
status_code=409,
50+
detail=f"Collection ID from path parameter ({collection_id}) does not match Collection ID from Item ({item_collection_id})",
51+
)
52+
item["collection"] = collection_id
3853
request = kwargs["request"]
3954
pool = request.app.state.writepool
4055
await dbfunc(pool, "update_item", item)

stac_fastapi/pgstac/tests/resources/test_item.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import json
2+
import random
23
import uuid
34
from datetime import timedelta
5+
from string import ascii_letters
46
from typing import Callable
57
from urllib.parse import parse_qs, urljoin, urlparse
68

@@ -79,6 +81,24 @@ async def test_create_item(app_client, load_test_data: Callable, load_test_colle
7981
assert in_item.dict(exclude={"links"}) == get_item.dict(exclude={"links"})
8082

8183

84+
async def test_create_item_mismatched_collection_id(
85+
app_client, load_test_data: Callable, load_test_collection
86+
):
87+
# If the collection_id path parameter and the Item's "collection" property do not match, a 409 response should
88+
# be returned.
89+
coll = load_test_collection
90+
91+
in_json = load_test_data("test_item.json")
92+
in_json["collection"] = random.choice(ascii_letters)
93+
assert in_json["collection"] != coll.id
94+
95+
resp = await app_client.post(
96+
f"/collections/{coll.id}/items",
97+
json=in_json,
98+
)
99+
assert resp.status_code == 409
100+
101+
82102
async def test_fetches_valid_item(
83103
app_client, load_test_data: Callable, load_test_collection
84104
):
@@ -126,6 +146,23 @@ async def test_update_item(
126146
assert get_item.properties.description == "Update Test"
127147

128148

149+
async def test_update_item_mismatched_collection_id(
150+
app_client, load_test_data: Callable, load_test_collection, load_test_item
151+
) -> None:
152+
coll = load_test_collection
153+
154+
in_json = load_test_data("test_item.json")
155+
156+
in_json["collection"] = random.choice(ascii_letters)
157+
assert in_json["collection"] != coll.id
158+
159+
resp = await app_client.put(
160+
f"/collections/{coll.id}/items",
161+
json=in_json,
162+
)
163+
assert resp.status_code == 409
164+
165+
129166
async def test_delete_item(
130167
app_client, load_test_data: Callable, load_test_collection, load_test_item
131168
):
@@ -201,7 +238,10 @@ async def test_create_item_missing_collection(
201238
item["collection"] = None
202239

203240
resp = await app_client.post(f"/collections/{coll.id}/items", json=item)
204-
assert resp.status_code == 424
241+
assert resp.status_code == 200
242+
243+
post_item = resp.json()
244+
assert post_item["collection"] == coll.id
205245

206246

207247
async def test_update_new_item(
@@ -223,7 +263,10 @@ async def test_update_item_missing_collection(
223263
item.collection = None
224264

225265
resp = await app_client.put(f"/collections/{coll.id}/items", content=item.json())
226-
assert resp.status_code == 424
266+
assert resp.status_code == 200
267+
268+
put_item = resp.json()
269+
assert put_item["collection"] == coll.id
227270

228271

229272
async def test_pagination(app_client, load_test_data, load_test_collection):

stac_fastapi/types/stac_fastapi/types/stac.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
"""STAC types."""
2-
from typing import Any, Dict, List, Optional, TypedDict, Union
2+
import sys
3+
from typing import Any, Dict, List, Optional, Union
4+
5+
# Avoids a Pydantic error:
6+
# TypeError: You should use `typing_extensions.TypedDict` instead of `typing.TypedDict` with Python < 3.9.2.
7+
# Without it, there is no way to differentiate required and optional fields when subclassed.
8+
if sys.version_info < (3, 9, 2):
9+
from typing_extensions import TypedDict
10+
else:
11+
from typing import TypedDict
312

413
NumType = Union[float, int]
514

0 commit comments

Comments
 (0)