|
1 | 1 | from datetime import datetime, timedelta
|
2 |
| -from typing import Any, Dict, List |
| 2 | +from typing import Any, Callable, Coroutine, Dict, List, Optional, TypeVar |
3 | 3 | from urllib.parse import quote_plus
|
4 | 4 |
|
5 | 5 | import orjson
|
6 | 6 | import pytest
|
| 7 | +from fastapi import Request |
| 8 | +from httpx import AsyncClient |
7 | 9 | from pystac import Collection, Extent, Item, SpatialExtent, TemporalExtent
|
| 10 | +from stac_fastapi.api.app import StacApi |
| 11 | +from stac_fastapi.api.models import create_post_request_model |
| 12 | +from stac_fastapi.extensions.core import FieldsExtension, TransactionExtension |
| 13 | +from stac_fastapi.types import stac as stac_types |
| 14 | + |
| 15 | +from stac_fastapi.pgstac.core import CoreCrudClient, Settings |
| 16 | +from stac_fastapi.pgstac.db import close_db_connection, connect_to_db |
| 17 | +from stac_fastapi.pgstac.transactions import TransactionsClient |
| 18 | +from stac_fastapi.pgstac.types.search import PgstacSearch |
8 | 19 |
|
9 | 20 | STAC_CORE_ROUTES = [
|
10 | 21 | "GET /",
|
@@ -622,3 +633,74 @@ async def search(query: Dict[str, Any]) -> List[Item]:
|
622 | 633 | }
|
623 | 634 | items = await search(query)
|
624 | 635 | assert len(items) == 10, items
|
| 636 | + |
| 637 | + |
| 638 | +@pytest.mark.asyncio |
| 639 | +async def test_wrapped_function(load_test_data) -> None: |
| 640 | + # Ensure wrappers, e.g. Planetary Computer's rate limiting, work. |
| 641 | + # https://github.com/gadomski/planetary-computer-apis/blob/2719ccf6ead3e06de0784c39a2918d4d1811368b/pccommon/pccommon/redis.py#L205-L238 |
| 642 | + |
| 643 | + T = TypeVar("T") |
| 644 | + |
| 645 | + def wrap() -> ( |
| 646 | + Callable[ |
| 647 | + [Callable[..., Coroutine[Any, Any, T]]], |
| 648 | + Callable[..., Coroutine[Any, Any, T]], |
| 649 | + ] |
| 650 | + ): |
| 651 | + def decorator( |
| 652 | + fn: Callable[..., Coroutine[Any, Any, T]] |
| 653 | + ) -> Callable[..., Coroutine[Any, Any, T]]: |
| 654 | + async def _wrapper(*args: Any, **kwargs: Any) -> T: |
| 655 | + request: Optional[Request] = kwargs.get("request") |
| 656 | + if request: |
| 657 | + pass # This is where rate limiting would be applied |
| 658 | + else: |
| 659 | + raise ValueError(f"Missing request in {fn.__name__}") |
| 660 | + return await fn(*args, **kwargs) |
| 661 | + |
| 662 | + return _wrapper |
| 663 | + |
| 664 | + return decorator |
| 665 | + |
| 666 | + class Client(CoreCrudClient): |
| 667 | + @wrap() |
| 668 | + async def get_collection( |
| 669 | + self, collection_id: str, request: Request, **kwargs |
| 670 | + ) -> stac_types.Item: |
| 671 | + return await super().get_collection( |
| 672 | + collection_id, request=request, **kwargs |
| 673 | + ) |
| 674 | + |
| 675 | + settings = Settings(testing=True) |
| 676 | + extensions = [ |
| 677 | + TransactionExtension(client=TransactionsClient(), settings=settings), |
| 678 | + FieldsExtension(), |
| 679 | + ] |
| 680 | + post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) |
| 681 | + api = StacApi( |
| 682 | + client=Client(post_request_model=post_request_model), |
| 683 | + settings=settings, |
| 684 | + extensions=extensions, |
| 685 | + search_post_request_model=post_request_model, |
| 686 | + ) |
| 687 | + app = api.app |
| 688 | + await connect_to_db(app) |
| 689 | + try: |
| 690 | + async with AsyncClient(app=app) as client: |
| 691 | + response = await client.post( |
| 692 | + "http://test/collections", |
| 693 | + json=load_test_data("test_collection.json"), |
| 694 | + ) |
| 695 | + assert response.status_code == 200 |
| 696 | + response = await client.post( |
| 697 | + "http://test/collections/test-collection/items", |
| 698 | + json=load_test_data("test_item.json"), |
| 699 | + ) |
| 700 | + assert response.status_code == 200 |
| 701 | + response = await client.get( |
| 702 | + "http://test/collections/test-collection/items/test-item" |
| 703 | + ) |
| 704 | + assert response.status_code == 200 |
| 705 | + finally: |
| 706 | + await close_db_connection(app) |
0 commit comments