Skip to content

Commit d1daa92

Browse files
remove create_sync_endpoint, run sync functions in background thread (#471)
* remove create_sync_endpoint, run sync functions in background thread * remove _create_endpoint instance methods * pass response class to filter extension endpoints
1 parent 9266980 commit d1daa92

File tree

5 files changed

+51
-144
lines changed

5 files changed

+51
-144
lines changed

stac_fastapi/api/stac_fastapi/api/app.py

Lines changed: 10 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
"""fastapi app creation."""
2-
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
2+
from typing import Any, Dict, List, Optional, Tuple, Type, Union
33

44
import attr
55
from brotli_asgi import BrotliMiddleware
66
from fastapi import APIRouter, FastAPI
77
from fastapi.openapi.utils import get_openapi
88
from fastapi.params import Depends
9-
from pydantic import BaseModel
109
from stac_pydantic import Collection, Item, ItemCollection
1110
from stac_pydantic.api import ConformanceClasses, LandingPage
1211
from stac_pydantic.api.collections import Collections
@@ -16,7 +15,6 @@
1615
from stac_fastapi.api.errors import DEFAULT_STATUS_CODES, add_exception_handlers
1716
from stac_fastapi.api.middleware import CORSMiddleware, ProxyHeaderMiddleware
1817
from stac_fastapi.api.models import (
19-
APIRequest,
2018
CollectionUri,
2119
EmptyRequest,
2220
GeoJSONResponse,
@@ -25,12 +23,7 @@
2523
create_request_model,
2624
)
2725
from stac_fastapi.api.openapi import update_openapi
28-
from stac_fastapi.api.routes import (
29-
Scope,
30-
add_route_dependencies,
31-
create_async_endpoint,
32-
create_sync_endpoint,
33-
)
26+
from stac_fastapi.api.routes import Scope, add_route_dependencies, create_async_endpoint
3427

3528
# TODO: make this module not depend on `stac_fastapi.extensions`
3629
from stac_fastapi.extensions.core import FieldsExtension, TokenPaginationExtension
@@ -113,19 +106,6 @@ def get_extension(self, extension: Type[ApiExtension]) -> Optional[ApiExtension]
113106
return ext
114107
return None
115108

116-
def _create_endpoint(
117-
self,
118-
func: Callable,
119-
request_type: Union[Type[APIRequest], Type[BaseModel]],
120-
resp_class: Type[Response],
121-
) -> Callable:
122-
"""Create a FastAPI endpoint."""
123-
if isinstance(self.client, AsyncBaseCoreClient):
124-
return create_async_endpoint(func, request_type, response_class=resp_class)
125-
elif isinstance(self.client, BaseCoreClient):
126-
return create_sync_endpoint(func, request_type, response_class=resp_class)
127-
raise NotImplementedError
128-
129109
def register_landing_page(self):
130110
"""Register landing page (GET /).
131111
@@ -142,7 +122,7 @@ def register_landing_page(self):
142122
response_model_exclude_unset=False,
143123
response_model_exclude_none=True,
144124
methods=["GET"],
145-
endpoint=self._create_endpoint(
125+
endpoint=create_async_endpoint(
146126
self.client.landing_page, EmptyRequest, self.response_class
147127
),
148128
)
@@ -163,7 +143,7 @@ def register_conformance_classes(self):
163143
response_model_exclude_unset=True,
164144
response_model_exclude_none=True,
165145
methods=["GET"],
166-
endpoint=self._create_endpoint(
146+
endpoint=create_async_endpoint(
167147
self.client.conformance, EmptyRequest, self.response_class
168148
),
169149
)
@@ -182,7 +162,7 @@ def register_get_item(self):
182162
response_model_exclude_unset=True,
183163
response_model_exclude_none=True,
184164
methods=["GET"],
185-
endpoint=self._create_endpoint(
165+
endpoint=create_async_endpoint(
186166
self.client.get_item, ItemUri, self.response_class
187167
),
188168
)
@@ -204,7 +184,7 @@ def register_post_search(self):
204184
response_model_exclude_unset=True,
205185
response_model_exclude_none=True,
206186
methods=["POST"],
207-
endpoint=self._create_endpoint(
187+
endpoint=create_async_endpoint(
208188
self.client.post_search, self.search_post_request_model, GeoJSONResponse
209189
),
210190
)
@@ -226,7 +206,7 @@ def register_get_search(self):
226206
response_model_exclude_unset=True,
227207
response_model_exclude_none=True,
228208
methods=["GET"],
229-
endpoint=self._create_endpoint(
209+
endpoint=create_async_endpoint(
230210
self.client.get_search, self.search_get_request_model, GeoJSONResponse
231211
),
232212
)
@@ -247,7 +227,7 @@ def register_get_collections(self):
247227
response_model_exclude_unset=True,
248228
response_model_exclude_none=True,
249229
methods=["GET"],
250-
endpoint=self._create_endpoint(
230+
endpoint=create_async_endpoint(
251231
self.client.all_collections, EmptyRequest, self.response_class
252232
),
253233
)
@@ -266,7 +246,7 @@ def register_get_collection(self):
266246
response_model_exclude_unset=True,
267247
response_model_exclude_none=True,
268248
methods=["GET"],
269-
endpoint=self._create_endpoint(
249+
endpoint=create_async_endpoint(
270250
self.client.get_collection, CollectionUri, self.response_class
271251
),
272252
)
@@ -297,7 +277,7 @@ def register_get_item_collection(self):
297277
response_model_exclude_unset=True,
298278
response_model_exclude_none=True,
299279
methods=["GET"],
300-
endpoint=self._create_endpoint(
280+
endpoint=create_async_endpoint(
301281
self.client.item_collection, request_model, self.response_class
302282
),
303283
)

stac_fastapi/api/stac_fastapi/api/routes.py

Lines changed: 20 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
"""route factories."""
2+
import functools
3+
import inspect
24
from typing import Any, Callable, Dict, List, Optional, Type, TypedDict, Union
35

46
from fastapi import Depends, params
57
from fastapi.dependencies.utils import get_parameterless_sub_dependant
68
from pydantic import BaseModel
9+
from starlette.concurrency import run_in_threadpool
710
from starlette.requests import Request
811
from starlette.responses import JSONResponse, Response
912
from starlette.routing import BaseRoute, Match
@@ -21,12 +24,28 @@ def _wrap_response(resp: Any, response_class: Type[Response]) -> Response:
2124
return Response(status_code=HTTP_204_NO_CONTENT)
2225

2326

27+
def sync_to_async(func):
28+
"""Run synchronous function asynchronously in a background thread."""
29+
30+
@functools.wraps(func)
31+
async def run(*args, **kwargs):
32+
return await run_in_threadpool(func, *args, **kwargs)
33+
34+
return run
35+
36+
2437
def create_async_endpoint(
2538
func: Callable,
2639
request_model: Union[Type[APIRequest], Type[BaseModel], Dict],
2740
response_class: Type[Response] = JSONResponse,
2841
):
29-
"""Wrap a coroutine in another coroutine which may be used to create a FastAPI endpoint."""
42+
"""Wrap a function in a coroutine which may be used to create a FastAPI endpoint.
43+
44+
Synchronous functions are executed asynchronously using a background thread.
45+
"""
46+
if not inspect.iscoroutinefunction(func):
47+
func = sync_to_async(func)
48+
3049
if issubclass(request_model, APIRequest):
3150

3251
async def _endpoint(
@@ -63,44 +82,6 @@ async def _endpoint(
6382
return _endpoint
6483

6584

66-
def create_sync_endpoint(
67-
func: Callable,
68-
request_model: Union[Type[APIRequest], Type[BaseModel], Dict],
69-
response_class: Type[Response] = JSONResponse,
70-
):
71-
"""Wrap a function in another function which may be used to create a FastAPI endpoint."""
72-
if issubclass(request_model, APIRequest):
73-
74-
def _endpoint(
75-
request: Request,
76-
request_data: request_model = Depends(), # type:ignore
77-
):
78-
"""Endpoint."""
79-
return _wrap_response(
80-
func(request=request, **request_data.kwargs()), response_class
81-
)
82-
83-
elif issubclass(request_model, BaseModel):
84-
85-
def _endpoint(
86-
request: Request,
87-
request_data: request_model, # type:ignore
88-
):
89-
"""Endpoint."""
90-
return _wrap_response(func(request_data, request=request), response_class)
91-
92-
else:
93-
94-
def _endpoint(
95-
request: Request,
96-
request_data: Dict[str, Any], # type:ignore
97-
):
98-
"""Endpoint."""
99-
return _wrap_response(func(request_data, request=request), response_class)
100-
101-
return _endpoint
102-
103-
10485
class Scope(TypedDict, total=False):
10586
"""More strict version of Starlette's Scope."""
10687

stac_fastapi/extensions/stac_fastapi/extensions/core/filter/filter.py

Lines changed: 9 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,14 @@
11
# encoding: utf-8
22
"""Filter Extension."""
33
from enum import Enum
4-
from typing import Callable, List, Type, Union
4+
from typing import List, Type, Union
55

66
import attr
77
from fastapi import APIRouter, FastAPI
88
from starlette.responses import Response
99

10-
from stac_fastapi.api.models import (
11-
APIRequest,
12-
CollectionUri,
13-
EmptyRequest,
14-
JSONSchemaResponse,
15-
)
16-
from stac_fastapi.api.routes import create_async_endpoint, create_sync_endpoint
10+
from stac_fastapi.api.models import CollectionUri, EmptyRequest, JSONSchemaResponse
11+
from stac_fastapi.api.routes import create_async_endpoint
1712
from stac_fastapi.types.core import AsyncBaseFiltersClient, BaseFiltersClient
1813
from stac_fastapi.types.extension import ApiExtension
1914

@@ -80,24 +75,6 @@ class FilterExtension(ApiExtension):
8075
router: APIRouter = attr.ib(factory=APIRouter)
8176
response_class: Type[Response] = attr.ib(default=JSONSchemaResponse)
8277

83-
def _create_endpoint(
84-
self,
85-
func: Callable,
86-
request_type: Union[
87-
Type[APIRequest],
88-
],
89-
) -> Callable:
90-
"""Create a FastAPI endpoint."""
91-
if isinstance(self.client, AsyncBaseFiltersClient):
92-
return create_async_endpoint(
93-
func, request_type, response_class=self.response_class
94-
)
95-
if isinstance(self.client, BaseFiltersClient):
96-
return create_sync_endpoint(
97-
func, request_type, response_class=self.response_class
98-
)
99-
raise NotImplementedError
100-
10178
def register(self, app: FastAPI) -> None:
10279
"""Register the extension with a FastAPI application.
10380
@@ -112,12 +89,16 @@ def register(self, app: FastAPI) -> None:
11289
name="Queryables",
11390
path="/queryables",
11491
methods=["GET"],
115-
endpoint=self._create_endpoint(self.client.get_queryables, EmptyRequest),
92+
endpoint=create_async_endpoint(
93+
self.client.get_queryables, EmptyRequest, self.response_class
94+
),
11695
)
11796
self.router.add_api_route(
11897
name="Collection Queryables",
11998
path="/collections/{collection_id}/queryables",
12099
methods=["GET"],
121-
endpoint=self._create_endpoint(self.client.get_queryables, CollectionUri),
100+
endpoint=create_async_endpoint(
101+
self.client.get_queryables, CollectionUri, self.response_class
102+
),
122103
)
123104
app.include_router(self.router, tags=["Filter Extension"])

stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
"""transaction extension."""
2-
from typing import Callable, List, Optional, Type, Union
2+
from typing import List, Optional, Type, Union
33

44
import attr
55
from fastapi import APIRouter, Body, FastAPI
6-
from pydantic import BaseModel
76
from stac_pydantic import Collection, Item
87
from starlette.responses import JSONResponse, Response
98

10-
from stac_fastapi.api.models import APIRequest, CollectionUri, ItemUri
11-
from stac_fastapi.api.routes import create_async_endpoint, create_sync_endpoint
9+
from stac_fastapi.api.models import CollectionUri, ItemUri
10+
from stac_fastapi.api.routes import create_async_endpoint
1211
from stac_fastapi.types import stac as stac_types
1312
from stac_fastapi.types.config import ApiSettings
1413
from stac_fastapi.types.core import AsyncBaseTransactionsClient, BaseTransactionsClient
@@ -60,27 +59,6 @@ class TransactionExtension(ApiExtension):
6059
router: APIRouter = attr.ib(factory=APIRouter)
6160
response_class: Type[Response] = attr.ib(default=JSONResponse)
6261

63-
def _create_endpoint(
64-
self,
65-
func: Callable,
66-
request_type: Union[
67-
Type[APIRequest],
68-
Type[BaseModel],
69-
Type[stac_types.Item],
70-
Type[stac_types.Collection],
71-
],
72-
) -> Callable:
73-
"""Create a FastAPI endpoint."""
74-
if isinstance(self.client, AsyncBaseTransactionsClient):
75-
return create_async_endpoint(
76-
func, request_type, response_class=self.response_class
77-
)
78-
elif isinstance(self.client, BaseTransactionsClient):
79-
return create_sync_endpoint(
80-
func, request_type, response_class=self.response_class
81-
)
82-
raise NotImplementedError
83-
8462
def register_create_item(self):
8563
"""Register create item endpoint (POST /collections/{collection_id}/items)."""
8664
self.router.add_api_route(
@@ -91,7 +69,7 @@ def register_create_item(self):
9169
response_model_exclude_unset=True,
9270
response_model_exclude_none=True,
9371
methods=["POST"],
94-
endpoint=self._create_endpoint(self.client.create_item, PostItem),
72+
endpoint=create_async_endpoint(self.client.create_item, PostItem),
9573
)
9674

9775
def register_update_item(self):
@@ -104,7 +82,7 @@ def register_update_item(self):
10482
response_model_exclude_unset=True,
10583
response_model_exclude_none=True,
10684
methods=["PUT"],
107-
endpoint=self._create_endpoint(self.client.update_item, PutItem),
85+
endpoint=create_async_endpoint(self.client.update_item, PutItem),
10886
)
10987

11088
def register_delete_item(self):
@@ -117,7 +95,7 @@ def register_delete_item(self):
11795
response_model_exclude_unset=True,
11896
response_model_exclude_none=True,
11997
methods=["DELETE"],
120-
endpoint=self._create_endpoint(self.client.delete_item, ItemUri),
98+
endpoint=create_async_endpoint(self.client.delete_item, ItemUri),
12199
)
122100

123101
def register_create_collection(self):
@@ -130,7 +108,7 @@ def register_create_collection(self):
130108
response_model_exclude_unset=True,
131109
response_model_exclude_none=True,
132110
methods=["POST"],
133-
endpoint=self._create_endpoint(
111+
endpoint=create_async_endpoint(
134112
self.client.create_collection, stac_types.Collection
135113
),
136114
)
@@ -145,7 +123,7 @@ def register_update_collection(self):
145123
response_model_exclude_unset=True,
146124
response_model_exclude_none=True,
147125
methods=["PUT"],
148-
endpoint=self._create_endpoint(
126+
endpoint=create_async_endpoint(
149127
self.client.update_collection, stac_types.Collection
150128
),
151129
)
@@ -160,7 +138,7 @@ def register_delete_collection(self):
160138
response_model_exclude_unset=True,
161139
response_model_exclude_none=True,
162140
methods=["DELETE"],
163-
endpoint=self._create_endpoint(
141+
endpoint=create_async_endpoint(
164142
self.client.delete_collection, CollectionUri
165143
),
166144
)

0 commit comments

Comments
 (0)