|
| 1 | +import logging |
1 | 2 | import uuid
|
| 3 | +from contextlib import asynccontextmanager |
2 | 4 | from copy import deepcopy
|
3 |
| -from typing import Callable |
| 5 | +from typing import Callable, Literal |
| 6 | + |
| 7 | +import pytest |
| 8 | +from fastapi import Request |
4 | 9 |
|
5 | 10 | from stac_pydantic import Collection, Item
|
| 11 | +from stac_fastapi.pgstac.db import connect_to_db, close_db_connection, get_connection |
6 | 12 |
|
7 | 13 | # from tests.conftest import MockStarletteRequest
|
| 14 | +logger = logging.getLogger(__name__) |
8 | 15 |
|
9 | 16 |
|
10 | 17 | async def test_create_collection(app_client, load_test_data: Callable):
|
@@ -170,3 +177,33 @@ async def test_create_bulk_items(
|
170 | 177 |
|
171 | 178 | # for item in fc.features:
|
172 | 179 | # assert item.collection == coll.id
|
| 180 | + |
| 181 | + |
| 182 | +@asynccontextmanager |
| 183 | +async def custom_get_connection( |
| 184 | + request: Request, |
| 185 | + readwrite: Literal["r", "w"], |
| 186 | +): |
| 187 | + """An example of customizing the connection getter""" |
| 188 | + async with get_connection(request, readwrite) as conn: |
| 189 | + await conn.execute("SELECT set_config('api.test', 'added-config', false)") |
| 190 | + yield conn |
| 191 | + |
| 192 | + |
| 193 | +class TestDbConnect: |
| 194 | + @pytest.fixture |
| 195 | + async def app(self, api_client): |
| 196 | + logger.warn("Customizing app setup") |
| 197 | + await connect_to_db(api_client.app, custom_get_connection) |
| 198 | + yield api_client.app |
| 199 | + await close_db_connection(api_client.app) |
| 200 | + |
| 201 | + async def test_db_setup(self, api_client, app_client): |
| 202 | + @api_client.app.get(f"{api_client.router.prefix}/db-test") |
| 203 | + async def example_view(request: Request): |
| 204 | + async with request.app.state.get_connection(request, "r") as conn: |
| 205 | + return await conn.fetchval("SELECT current_setting('api.test', true)") |
| 206 | + |
| 207 | + response = await app_client.get("/db-test") |
| 208 | + assert response.status_code is 200 |
| 209 | + assert response.json() == "added-config" |
0 commit comments