Skip to content

refactoring to create dynamic request model #213

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions stac_fastapi/api/stac_fastapi/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from fastapi.openapi.utils import get_openapi
from pydantic import BaseModel
from stac_pydantic import Collection, Item, ItemCollection
from stac_pydantic.api import ConformanceClasses, LandingPage, Search
from stac_pydantic.api import ConformanceClasses, LandingPage
from stac_pydantic.api.collections import Collections
from stac_pydantic.version import STAC_VERSION
from starlette.responses import JSONResponse, Response
Expand All @@ -19,17 +19,16 @@
EmptyRequest,
ItemCollectionUri,
ItemUri,
SearchGetRequest,
_create_request_model,
create_request_model,
)
from stac_fastapi.api.routes import create_async_endpoint, create_sync_endpoint

# TODO: make this module not depend on `stac_fastapi.extensions`
from stac_fastapi.extensions.core import FieldsExtension
from stac_fastapi.extensions.core import FieldsExtension, TokenPaginationExtension
from stac_fastapi.types.config import ApiSettings, Settings
from stac_fastapi.types.core import AsyncBaseCoreClient, BaseCoreClient
from stac_fastapi.types.extension import ApiExtension
from stac_fastapi.types.search import STACSearch
from stac_fastapi.types.search import BaseSearchGetRequest, BaseSearchPostRequest


@attr.s
Expand Down Expand Up @@ -72,9 +71,13 @@ class StacApi:
api_version: str = attr.ib(default="0.1")
stac_version: str = attr.ib(default=STAC_VERSION)
description: str = attr.ib(default="stac-fastapi")
search_request_model: Type[Search] = attr.ib(default=STACSearch)
search_get_request: Type[SearchGetRequest] = attr.ib(default=SearchGetRequest)
item_collection_uri: Type[ItemCollectionUri] = attr.ib(default=ItemCollectionUri)
search_get_request_model: Type[BaseSearchGetRequest] = attr.ib(
default=BaseSearchGetRequest
)
search_post_request_model: Type[BaseSearchPostRequest] = attr.ib(
default=BaseSearchPostRequest
)
pagination_extension = attr.ib(default=TokenPaginationExtension)
response_class: Type[Response] = attr.ib(default=JSONResponse)
middlewares: List = attr.ib(default=attr.Factory(lambda: [BrotliMiddleware]))

Expand Down Expand Up @@ -167,7 +170,6 @@ def register_post_search(self):
Returns:
None
"""
search_request_model = _create_request_model(self.search_request_model)
fields_ext = self.get_extension(FieldsExtension)
self.router.add_api_route(
name="Search",
Expand All @@ -180,7 +182,7 @@ def register_post_search(self):
response_model_exclude_none=True,
methods=["POST"],
endpoint=self._create_endpoint(
self.client.post_search, search_request_model
self.client.post_search, self.search_post_request_model
),
)

Expand All @@ -202,7 +204,7 @@ def register_get_search(self):
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(
self.client.get_search, self.search_get_request
self.client.get_search, self.search_get_request_model
),
)

Expand Down Expand Up @@ -248,6 +250,12 @@ def register_get_item_collection(self):
Returns:
None
"""
get_pagination_model = self.get_extension(self.pagination_extension).GET
request_model = create_request_model(
"ItemCollectionURI",
base_model=ItemCollectionUri,
mixins=[get_pagination_model],
)
self.router.add_api_route(
name="Get ItemCollection",
path="/collections/{collectionId}/items",
Expand All @@ -258,9 +266,7 @@ def register_get_item_collection(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(
self.client.item_collection, self.item_collection_uri
),
endpoint=self._create_endpoint(self.client.item_collection, request_model),
)

def register_core(self):
Expand Down
195 changes: 108 additions & 87 deletions stac_fastapi/api/stac_fastapi/api/models.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,99 @@
"""api request/response models."""

import abc
from typing import Dict, Optional, Type, Union
from typing import Optional, Type, Union

import attr
from fastapi import Body, Path
from pydantic import BaseModel, create_model
from pydantic.fields import UndefinedType


def _create_request_model(model: Type[BaseModel]) -> Type[BaseModel]:
from stac_fastapi.types.extension import ApiExtension
from stac_fastapi.types.search import (
APIRequest,
BaseSearchGetRequest,
BaseSearchPostRequest,
)


def create_request_model(
model_name="SearchGetRequest",
base_model: Union[Type[BaseModel], APIRequest] = BaseSearchGetRequest,
extensions: Optional[ApiExtension] = None,
mixins: Optional[Union[BaseModel, APIRequest]] = None,
request_type: Optional[str] = "GET",
) -> Union[Type[BaseModel], APIRequest]:
"""Create a pydantic model for validating request bodies."""
fields = {}
for (k, v) in model.__fields__.items():
# TODO: Filter out fields based on which extensions are present
field_info = v.field_info
body = Body(
None
if isinstance(field_info.default, UndefinedType)
else field_info.default,
default_factory=field_info.default_factory,
alias=field_info.alias,
alias_priority=field_info.alias_priority,
title=field_info.title,
description=field_info.description,
const=field_info.const,
gt=field_info.gt,
ge=field_info.ge,
lt=field_info.lt,
le=field_info.le,
multiple_of=field_info.multiple_of,
min_items=field_info.min_items,
max_items=field_info.max_items,
min_length=field_info.min_length,
max_length=field_info.max_length,
regex=field_info.regex,
extra=field_info.extra,
)
fields[k] = (v.outer_type_, body)
return create_model(model.__name__, **fields, __base__=model)


@attr.s # type:ignore
class APIRequest(abc.ABC):
"""Generic API Request base class."""

@abc.abstractmethod
def kwargs(self) -> Dict:
"""Transform api request params into format which matches the signature of the endpoint."""
...
extension_models = []

# Check extensions for additional parameters to search
for extension in extensions or []:
if extension_model := extension.get_request_model(request_type):
extension_models.append(extension_model)

mixins = mixins or []

models = [base_model] + extension_models + mixins

# Handle GET requests
if all([issubclass(m, APIRequest) for m in models]):
return attr.make_class(model_name, attrs={}, bases=tuple(models))

# Handle POST requests
elif all([issubclass(m, BaseModel) for m in models]):
for model in models:
for (k, v) in model.__fields__.items():
field_info = v.field_info
body = Body(
None
if isinstance(field_info.default, UndefinedType)
else field_info.default,
default_factory=field_info.default_factory,
alias=field_info.alias,
alias_priority=field_info.alias_priority,
title=field_info.title,
description=field_info.description,
const=field_info.const,
gt=field_info.gt,
ge=field_info.ge,
lt=field_info.lt,
le=field_info.le,
multiple_of=field_info.multiple_of,
min_items=field_info.min_items,
max_items=field_info.max_items,
min_length=field_info.min_length,
max_length=field_info.max_length,
regex=field_info.regex,
extra=field_info.extra,
)
fields[k] = (v.outer_type_, body)
return create_model(model_name, **fields, __base__=base_model)

raise TypeError("Mixed Request Model types. Check extension request types.")


def create_get_request_model(
extensions, base_model: BaseSearchGetRequest = BaseSearchGetRequest
):
"""Wrap create_request_model to create the GET request model."""
return create_request_model(
"SearchGetRequest",
base_model=BaseSearchGetRequest,
extensions=extensions,
request_type="GET",
)


def create_post_request_model(
extensions, base_model: BaseSearchPostRequest = BaseSearchGetRequest
):
"""Wrap create_request_model to create the POST request model."""
return create_request_model(
"SearchPostRequest",
base_model=BaseSearchPostRequest,
extensions=extensions,
request_type="POST",
)


@attr.s # type:ignore
Expand All @@ -57,73 +102,49 @@ class CollectionUri(APIRequest):

collectionId: str = attr.ib(default=Path(..., description="Collection ID"))

def kwargs(self) -> Dict:
"""kwargs."""
return {"id": self.collectionId}


@attr.s
class ItemUri(CollectionUri):
"""Delete item."""

itemId: str = attr.ib(default=Path(..., description="Item ID"))

def kwargs(self) -> Dict:
"""kwargs."""
return {"collection_id": self.collectionId, "item_id": self.itemId}


@attr.s
class EmptyRequest(APIRequest):
"""Empty request."""

def kwargs(self) -> Dict:
"""kwargs."""
return {}
...


@attr.s
class ItemCollectionUri(CollectionUri):
"""Get item collection."""

limit: int = attr.ib(default=10)
token: str = attr.ib(default=None)

def kwargs(self) -> Dict:
"""kwargs."""
return {
"id": self.collectionId,
"limit": self.limit,
"token": self.token,
}

class POSTTokenPagination(BaseModel):
"""Token pagination model for POST requests."""

token: Optional[str] = None


@attr.s
class SearchGetRequest(APIRequest):
"""GET search request."""

collections: Optional[str] = attr.ib(default=None)
ids: Optional[str] = attr.ib(default=None)
bbox: Optional[str] = attr.ib(default=None)
datetime: Optional[Union[str]] = attr.ib(default=None)
limit: Optional[int] = attr.ib(default=10)
query: Optional[str] = attr.ib(default=None)
class GETTokenPagination(APIRequest):
"""Token pagination for GET requests."""

token: Optional[str] = attr.ib(default=None)
fields: Optional[str] = attr.ib(default=None)
sortby: Optional[str] = attr.ib(default=None)

def kwargs(self) -> Dict:
"""kwargs."""
return {
"collections": self.collections.split(",")
if self.collections
else self.collections,
"ids": self.ids.split(",") if self.ids else self.ids,
"bbox": self.bbox.split(",") if self.bbox else self.bbox,
"datetime": self.datetime,
"limit": self.limit,
"query": self.query,
"token": self.token,
"fields": self.fields.split(",") if self.fields else self.fields,
"sortby": self.sortby.split(",") if self.sortby else self.sortby,
}


class POSTPagination(BaseModel):
"""Page based pagination for POST requests."""

page: Optional[str] = None


@attr.s
class GETPagination(APIRequest):
"""Page based pagination for GET requests."""

page: Optional[str] = attr.ib(default=None)
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .context import ContextExtension
from .fields import FieldsExtension
from .filter import FilterExtension
from .pagination import PaginationExtension, TokenPaginationExtension
from .query import QueryExtension
from .sort import SortExtension
from .transaction import TransactionExtension
Expand All @@ -12,8 +13,10 @@
"ContextExtension",
"FieldsExtension",
"FilterExtension",
"PaginationExtension",
"QueryExtension",
"SortExtension",
"TilesExtension",
"TokenPaginationExtension",
"TransactionExtension",
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Fields extension module."""


from .fields import FieldsExtension

__all__ = ["FieldsExtension"]
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from stac_fastapi.types.extension import ApiExtension

from .request import FieldsExtensionGetRequest, FieldsExtensionPostRequest


@attr.s
class FieldsExtension(ApiExtension):
Expand All @@ -24,10 +26,12 @@ class FieldsExtension(ApiExtension):

"""

GET = FieldsExtensionGetRequest
POST = FieldsExtensionPostRequest

conformance_classes: List[str] = attr.ib(
factory=lambda: ["https://api.stacspec.org/v1.0.0-beta.3/item-search/#fields"]
)
schema_href: Optional[str] = attr.ib(default=None)
default_includes: Set[str] = attr.ib(
factory=lambda: {
"id",
Expand All @@ -41,6 +45,7 @@ class FieldsExtension(ApiExtension):
"collection",
}
)
schema_href: Optional[str] = attr.ib(default=None)

def register(self, app: FastAPI) -> None:
"""Register the extension with a FastAPI application.
Expand Down
Loading