Skip to content

Commit 625564d

Browse files
committed
Add test for customizing the connection_getter
1 parent 72dd57d commit 625564d

File tree

1 file changed

+38
-1
lines changed

1 file changed

+38
-1
lines changed

stac_fastapi/pgstac/tests/clients/test_postgres.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
1+
import logging
12
import uuid
3+
from contextlib import asynccontextmanager
24
from copy import deepcopy
3-
from typing import Callable
5+
from typing import Callable, Literal
6+
7+
import pytest
8+
from fastapi import Request
49

510
from stac_pydantic import Collection, Item
11+
from stac_fastapi.pgstac.db import connect_to_db, close_db_connection, get_connection
612

713
# from tests.conftest import MockStarletteRequest
14+
logger = logging.getLogger(__name__)
815

916

1017
async def test_create_collection(app_client, load_test_data: Callable):
@@ -170,3 +177,33 @@ async def test_create_bulk_items(
170177

171178
# for item in fc.features:
172179
# assert item.collection == coll.id
180+
181+
182+
@asynccontextmanager
183+
async def custom_get_connection(
184+
request: Request,
185+
readwrite: Literal["r", "w"],
186+
):
187+
"""An example of customizing the connection getter"""
188+
async with get_connection(request, readwrite) as conn:
189+
await conn.execute("SELECT set_config('api.test', 'added-config', false)")
190+
yield conn
191+
192+
193+
class TestDbConnect:
194+
@pytest.fixture
195+
async def app(self, api_client):
196+
logger.warn("Customizing app setup")
197+
await connect_to_db(api_client.app, custom_get_connection)
198+
yield api_client.app
199+
await close_db_connection(api_client.app)
200+
201+
async def test_db_setup(self, api_client, app_client):
202+
@api_client.app.get(f"{api_client.router.prefix}/db-test")
203+
async def example_view(request: Request):
204+
async with request.app.state.get_connection(request, "r") as conn:
205+
return await conn.fetchval("SELECT current_setting('api.test', true)")
206+
207+
response = await app_client.get("/db-test")
208+
assert response.status_code is 200
209+
assert response.json() == "added-config"

0 commit comments

Comments
 (0)