Skip to content

Commit c844034

Browse files
moradologylossyrob
andauthored
Set content-type to geojson for search results (#288)
* Set content-type to geojson for search results * Update changelog * Test for ORJSON availability Co-authored-by: Rob Emanuele <[email protected]>
1 parent 185b09c commit c844034

File tree

5 files changed

+74
-17
lines changed

5 files changed

+74
-17
lines changed

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
### Fixed
1212

13+
* Content-type response headers for the /search endpoint now reflect the geojson response expected in the STAC api spec ([#220](https://github.com/stac-utils/stac-fastapi/issues/220)
1314
* The minimum `limit` value for searches is now 1 ([#296](https://github.com/stac-utils/stac-fastapi/pull/296))
1415
* Links stored with Collections and Items (e.g. license links) are now returned with those STAC objects ([#282](https://github.com/stac-utils/stac-fastapi/pull/282))
1516
* Content-type response headers for the /api endpoint now reflect those expected in the STAC api spec ([#287](https://github.com/stac-utils/stac-fastapi/pull/287))

stac_fastapi/api/stac_fastapi/api/app.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
APIRequest,
1818
CollectionUri,
1919
EmptyRequest,
20+
GeoJSONResponse,
2021
ItemCollectionUri,
2122
ItemUri,
2223
SearchGetRequest,
@@ -96,17 +97,16 @@ def get_extension(self, extension: Type[ApiExtension]) -> Optional[ApiExtension]
9697
return None
9798

9899
def _create_endpoint(
99-
self, func: Callable, request_type: Union[Type[APIRequest], Type[BaseModel]]
100+
self,
101+
func: Callable,
102+
request_type: Union[Type[APIRequest], Type[BaseModel]],
103+
resp_class: Type[Response],
100104
) -> Callable:
101105
"""Create a FastAPI endpoint."""
102106
if isinstance(self.client, AsyncBaseCoreClient):
103-
return create_async_endpoint(
104-
func, request_type, response_class=self.response_class
105-
)
107+
return create_async_endpoint(func, request_type, response_class=resp_class)
106108
elif isinstance(self.client, BaseCoreClient):
107-
return create_sync_endpoint(
108-
func, request_type, response_class=self.response_class
109-
)
109+
return create_sync_endpoint(func, request_type, response_class=resp_class)
110110
raise NotImplementedError
111111

112112
def register_landing_page(self):
@@ -125,7 +125,9 @@ def register_landing_page(self):
125125
response_model_exclude_unset=False,
126126
response_model_exclude_none=True,
127127
methods=["GET"],
128-
endpoint=self._create_endpoint(self.client.landing_page, EmptyRequest),
128+
endpoint=self._create_endpoint(
129+
self.client.landing_page, EmptyRequest, self.response_class
130+
),
129131
)
130132

131133
def register_conformance_classes(self):
@@ -144,7 +146,9 @@ def register_conformance_classes(self):
144146
response_model_exclude_unset=True,
145147
response_model_exclude_none=True,
146148
methods=["GET"],
147-
endpoint=self._create_endpoint(self.client.conformance, EmptyRequest),
149+
endpoint=self._create_endpoint(
150+
self.client.conformance, EmptyRequest, self.response_class
151+
),
148152
)
149153

150154
def register_get_item(self):
@@ -161,7 +165,9 @@ def register_get_item(self):
161165
response_model_exclude_unset=True,
162166
response_model_exclude_none=True,
163167
methods=["GET"],
164-
endpoint=self._create_endpoint(self.client.get_item, ItemUri),
168+
endpoint=self._create_endpoint(
169+
self.client.get_item, ItemUri, self.response_class
170+
),
165171
)
166172

167173
def register_post_search(self):
@@ -178,12 +184,12 @@ def register_post_search(self):
178184
response_model=(ItemCollection if not fields_ext else None)
179185
if self.settings.enable_response_models
180186
else None,
181-
response_class=self.response_class,
187+
response_class=GeoJSONResponse,
182188
response_model_exclude_unset=True,
183189
response_model_exclude_none=True,
184190
methods=["POST"],
185191
endpoint=self._create_endpoint(
186-
self.client.post_search, search_request_model
192+
self.client.post_search, search_request_model, GeoJSONResponse
187193
),
188194
)
189195

@@ -200,12 +206,12 @@ def register_get_search(self):
200206
response_model=(ItemCollection if not fields_ext else None)
201207
if self.settings.enable_response_models
202208
else None,
203-
response_class=self.response_class,
209+
response_class=GeoJSONResponse,
204210
response_model_exclude_unset=True,
205211
response_model_exclude_none=True,
206212
methods=["GET"],
207213
endpoint=self._create_endpoint(
208-
self.client.get_search, self.search_get_request
214+
self.client.get_search, self.search_get_request, GeoJSONResponse
209215
),
210216
)
211217

@@ -225,7 +231,9 @@ def register_get_collections(self):
225231
response_model_exclude_unset=True,
226232
response_model_exclude_none=True,
227233
methods=["GET"],
228-
endpoint=self._create_endpoint(self.client.all_collections, EmptyRequest),
234+
endpoint=self._create_endpoint(
235+
self.client.all_collections, EmptyRequest, self.response_class
236+
),
229237
)
230238

231239
def register_get_collection(self):
@@ -242,7 +250,9 @@ def register_get_collection(self):
242250
response_model_exclude_unset=True,
243251
response_model_exclude_none=True,
244252
methods=["GET"],
245-
endpoint=self._create_endpoint(self.client.get_collection, CollectionUri),
253+
endpoint=self._create_endpoint(
254+
self.client.get_collection, CollectionUri, self.response_class
255+
),
246256
)
247257

248258
def register_get_item_collection(self):
@@ -262,7 +272,9 @@ def register_get_item_collection(self):
262272
response_model_exclude_none=True,
263273
methods=["GET"],
264274
endpoint=self._create_endpoint(
265-
self.client.item_collection, self.item_collection_uri
275+
self.client.item_collection,
276+
self.item_collection_uri,
277+
self.response_class,
266278
),
267279
)
268280

stac_fastapi/api/stac_fastapi/api/models.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""api request/response models."""
22

33
import abc
4+
import importlib
45
from typing import Dict, Optional, Type, Union
56

67
import attr
@@ -127,3 +128,22 @@ def kwargs(self) -> Dict:
127128
"fields": self.fields.split(",") if self.fields else self.fields,
128129
"sortby": self.sortby.split(",") if self.sortby else self.sortby,
129130
}
131+
132+
133+
# Test for ORJSON and use it rather than stdlib JSON where supported
134+
if importlib.util.find_spec("orjson") is not None:
135+
from fastapi.responses import ORJSONResponse
136+
137+
class GeoJSONResponse(ORJSONResponse):
138+
"""JSON with custom, vendor content-type."""
139+
140+
media_type = "application/geo+json"
141+
142+
143+
else:
144+
from starlette.responses import JSONResponse
145+
146+
class GeoJSONResponse(JSONResponse):
147+
"""JSON with custom, vendor content-type."""
148+
149+
media_type = "application/geo+json"

stac_fastapi/pgstac/tests/api/test_api.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,19 @@
2323
]
2424

2525

26+
@pytest.mark.asyncio
27+
async def test_post_search_content_type(app_client):
28+
params = {"limit": 1}
29+
resp = await app_client.post("search", json=params)
30+
assert resp.headers["content-type"] == "application/geo+json"
31+
32+
33+
@pytest.mark.asyncio
34+
async def test_get_search_content_type(app_client):
35+
resp = await app_client.get("search")
36+
assert resp.headers["content-type"] == "application/geo+json"
37+
38+
2639
@pytest.mark.asyncio
2740
async def test_api_headers(app_client):
2841
resp = await app_client.get("/api")

stac_fastapi/sqlalchemy/tests/api/test_api.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,17 @@
2323
]
2424

2525

26+
def test_post_search_content_type(app_client):
27+
params = {"limit": 1}
28+
resp = app_client.post("search", json=params)
29+
assert resp.headers["content-type"] == "application/geo+json"
30+
31+
32+
def test_get_search_content_type(app_client):
33+
resp = app_client.get("search")
34+
assert resp.headers["content-type"] == "application/geo+json"
35+
36+
2637
def test_api_headers(app_client):
2738
resp = app_client.get("/api")
2839
assert (

0 commit comments

Comments
 (0)