|
1 | 1 | from datetime import datetime, timedelta
|
2 |
| -from typing import Any, Dict, List |
| 2 | +from typing import Any, Callable, Coroutine, Dict, List, Optional, TypeVar |
3 | 3 | from urllib.parse import quote_plus
|
4 | 4 |
|
5 | 5 | import orjson
|
6 | 6 | import pytest
|
| 7 | +from fastapi import Request |
| 8 | +from httpx import AsyncClient |
7 | 9 | from pystac import Collection, Extent, Item, SpatialExtent, TemporalExtent
|
| 10 | +from stac_fastapi.api.app import StacApi |
| 11 | +from stac_fastapi.types import stac as stac_types |
| 12 | + |
| 13 | +from stac_fastapi.pgstac.core import CoreCrudClient, Settings |
| 14 | +from stac_fastapi.pgstac.db import close_db_connection, connect_to_db |
8 | 15 |
|
9 | 16 | STAC_CORE_ROUTES = [
|
10 | 17 | "GET /",
|
@@ -622,3 +629,47 @@ async def search(query: Dict[str, Any]) -> List[Item]:
|
622 | 629 | }
|
623 | 630 | items = await search(query)
|
624 | 631 | assert len(items) == 10, items
|
| 632 | + |
| 633 | + |
| 634 | +@pytest.mark.asyncio |
| 635 | +async def test_wrapped_function() -> None: |
| 636 | + # Ensure wrappers, e.g. Planetary Computer's rate limiting, work. |
| 637 | + # https://github.com/gadomski/planetary-computer-apis/blob/2719ccf6ead3e06de0784c39a2918d4d1811368b/pccommon/pccommon/redis.py#L205-L238 |
| 638 | + |
| 639 | + T = TypeVar("T") |
| 640 | + |
| 641 | + def wrap() -> ( |
| 642 | + Callable[ |
| 643 | + [Callable[..., Coroutine[Any, Any, T]]], |
| 644 | + Callable[..., Coroutine[Any, Any, T]], |
| 645 | + ] |
| 646 | + ): |
| 647 | + def decorator( |
| 648 | + fn: Callable[..., Coroutine[Any, Any, T]] |
| 649 | + ) -> Callable[..., Coroutine[Any, Any, T]]: |
| 650 | + async def _wrapper(*args: Any, **kwargs: Any) -> T: |
| 651 | + request: Optional[Request] = kwargs.get("request") |
| 652 | + if request: |
| 653 | + pass # This is where rate limiting would be applied |
| 654 | + else: |
| 655 | + raise ValueError(f"Missing request in {fn.__name__}") |
| 656 | + return await fn(*args, **kwargs) |
| 657 | + |
| 658 | + return _wrapper |
| 659 | + |
| 660 | + return decorator |
| 661 | + |
| 662 | + class Client(CoreCrudClient): |
| 663 | + @wrap() |
| 664 | + async def all_collections(self, **kwargs) -> stac_types.Collections: |
| 665 | + return await super().all_collections(**kwargs) |
| 666 | + |
| 667 | + api = StacApi(client=Client(), settings=Settings(testing=True)) |
| 668 | + app = api.app |
| 669 | + await connect_to_db(app) |
| 670 | + try: |
| 671 | + async with AsyncClient(app=app) as client: |
| 672 | + response = await client.get("http://test/collections") |
| 673 | + assert response.status_code == 200 |
| 674 | + finally: |
| 675 | + await close_db_connection(app) |
0 commit comments