Skip to content

Commit 659e46f

Browse files
authored
Pass request by name into methods (#44)
* fix: pass request by name into methods This makes #22 less breaking. * chore: update changelog
1 parent cbe55e0 commit 659e46f

File tree

3 files changed

+93
-7
lines changed

3 files changed

+93
-7
lines changed

CHANGES.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
## [Unreleased]
44

5+
### Fixed
6+
7+
- Pass `request` by name when calling endpoints from other endpoints ([#44](https://github.com/stac-utils/stac-fastapi-pgstac/pull/44))
8+
59
## [2.4.8] - 2023-06-08
610

711
### Changed

stac_fastapi/pgstac/core.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ async def _add_item_links(
213213
if settings.use_api_hydrate:
214214

215215
async def _get_base_item(collection_id: str) -> Dict[str, Any]:
216-
return await self._get_base_item(collection_id, request)
216+
return await self._get_base_item(collection_id, request=request)
217217

218218
base_item_cache = settings.base_item_cache(
219219
fetch_base_item=_get_base_item, request=request
@@ -267,7 +267,7 @@ async def item_collection(
267267
An ItemCollection.
268268
"""
269269
# If collection does not exist, NotFoundError wil be raised
270-
await self.get_collection(collection_id, request)
270+
await self.get_collection(collection_id, request=request)
271271

272272
base_args = {
273273
"collections": [collection_id],
@@ -285,7 +285,7 @@ async def item_collection(
285285
search_request = self.post_request_model(
286286
**clean,
287287
)
288-
item_collection = await self._search_base(search_request, request)
288+
item_collection = await self._search_base(search_request, request=request)
289289
links = await ItemCollectionLinks(
290290
collection_id=collection_id, request=request
291291
).get_links(extra_links=item_collection["links"])
@@ -307,12 +307,12 @@ async def get_item(
307307
Item.
308308
"""
309309
# If collection does not exist, NotFoundError wil be raised
310-
await self.get_collection(collection_id, request)
310+
await self.get_collection(collection_id, request=request)
311311

312312
search_request = self.post_request_model(
313313
ids=[item_id], collections=[collection_id], limit=1
314314
)
315-
item_collection = await self._search_base(search_request, request)
315+
item_collection = await self._search_base(search_request, request=request)
316316
if not item_collection["features"]:
317317
raise NotFoundError(
318318
f"Item {item_id} in Collection {collection_id} does not exist."
@@ -333,7 +333,7 @@ async def post_search(
333333
Returns:
334334
ItemCollection containing items which match the search criteria.
335335
"""
336-
item_collection = await self._search_base(search_request, request)
336+
item_collection = await self._search_base(search_request, request=request)
337337
return ItemCollection(**item_collection)
338338

339339
async def get_search(

tests/api/test_api.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,21 @@
11
from datetime import datetime, timedelta
2-
from typing import Any, Dict, List
2+
from typing import Any, Callable, Coroutine, Dict, List, Optional, TypeVar
33
from urllib.parse import quote_plus
44

55
import orjson
66
import pytest
7+
from fastapi import Request
8+
from httpx import AsyncClient
79
from pystac import Collection, Extent, Item, SpatialExtent, TemporalExtent
10+
from stac_fastapi.api.app import StacApi
11+
from stac_fastapi.api.models import create_post_request_model
12+
from stac_fastapi.extensions.core import FieldsExtension, TransactionExtension
13+
from stac_fastapi.types import stac as stac_types
14+
15+
from stac_fastapi.pgstac.core import CoreCrudClient, Settings
16+
from stac_fastapi.pgstac.db import close_db_connection, connect_to_db
17+
from stac_fastapi.pgstac.transactions import TransactionsClient
18+
from stac_fastapi.pgstac.types.search import PgstacSearch
819

920
STAC_CORE_ROUTES = [
1021
"GET /",
@@ -622,3 +633,74 @@ async def search(query: Dict[str, Any]) -> List[Item]:
622633
}
623634
items = await search(query)
624635
assert len(items) == 10, items
636+
637+
638+
@pytest.mark.asyncio
639+
async def test_wrapped_function(load_test_data) -> None:
640+
# Ensure wrappers, e.g. Planetary Computer's rate limiting, work.
641+
# https://github.com/gadomski/planetary-computer-apis/blob/2719ccf6ead3e06de0784c39a2918d4d1811368b/pccommon/pccommon/redis.py#L205-L238
642+
643+
T = TypeVar("T")
644+
645+
def wrap() -> (
646+
Callable[
647+
[Callable[..., Coroutine[Any, Any, T]]],
648+
Callable[..., Coroutine[Any, Any, T]],
649+
]
650+
):
651+
def decorator(
652+
fn: Callable[..., Coroutine[Any, Any, T]]
653+
) -> Callable[..., Coroutine[Any, Any, T]]:
654+
async def _wrapper(*args: Any, **kwargs: Any) -> T:
655+
request: Optional[Request] = kwargs.get("request")
656+
if request:
657+
pass # This is where rate limiting would be applied
658+
else:
659+
raise ValueError(f"Missing request in {fn.__name__}")
660+
return await fn(*args, **kwargs)
661+
662+
return _wrapper
663+
664+
return decorator
665+
666+
class Client(CoreCrudClient):
667+
@wrap()
668+
async def get_collection(
669+
self, collection_id: str, request: Request, **kwargs
670+
) -> stac_types.Item:
671+
return await super().get_collection(
672+
collection_id, request=request, **kwargs
673+
)
674+
675+
settings = Settings(testing=True)
676+
extensions = [
677+
TransactionExtension(client=TransactionsClient(), settings=settings),
678+
FieldsExtension(),
679+
]
680+
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
681+
api = StacApi(
682+
client=Client(post_request_model=post_request_model),
683+
settings=settings,
684+
extensions=extensions,
685+
search_post_request_model=post_request_model,
686+
)
687+
app = api.app
688+
await connect_to_db(app)
689+
try:
690+
async with AsyncClient(app=app) as client:
691+
response = await client.post(
692+
"http://test/collections",
693+
json=load_test_data("test_collection.json"),
694+
)
695+
assert response.status_code == 200
696+
response = await client.post(
697+
"http://test/collections/test-collection/items",
698+
json=load_test_data("test_item.json"),
699+
)
700+
assert response.status_code == 200
701+
response = await client.get(
702+
"http://test/collections/test-collection/items/test-item"
703+
)
704+
assert response.status_code == 200
705+
finally:
706+
await close_db_connection(app)

0 commit comments

Comments
 (0)