Skip to content

Commit c1aa6bd

Browse files
palfreysobolevn
andauthored
Adds DjangoFilterBackend option for filter_backends (#154)
* Support DjangoFilterBackend for filter_backends * Add django-filter requirement to setup.py * Redo base filter as protocol * Add standard BaseFilterBackend as an option * Try doing model type work on BaseFilterProtocol * Update generics.pyi * Add invariant type variable for BaseFilterProtocol * Add type variable to BaseFilterProtocol Co-authored-by: Nikita Sobolev <[email protected]>
1 parent 27533ac commit c1aa6bd

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

rest_framework-stubs/generics.pyi

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ from rest_framework.response import Response
1111
from rest_framework.serializers import BaseSerializer
1212

1313
_MT_co = TypeVar("_MT_co", bound=Model, covariant=True)
14+
_MT_inv = TypeVar("_MT_inv", bound=Model)
1415

1516
def get_object_or_404(
1617
queryset: Union[Type[_MT_co], Manager[_MT_co], QuerySet[_MT_co]], *filter_args: Any, **filter_kwargs: Any
@@ -19,12 +20,19 @@ def get_object_or_404(
1920
class UsesQuerySet(Protocol[_MT_co]):
2021
def get_queryset(self) -> QuerySet[_MT_co]: ...
2122

23+
# Can't just use BaseFilterBackend because there's also things like django_filters.rest_framework.DjangoFilterBackend that are
24+
# valid options but don't extend it
25+
class BaseFilterProtocol(Protocol[_MT_inv]):
26+
def filter_queryset(self, request: Request, queryset: QuerySet[_MT_inv], view: views.APIView) -> QuerySet[_MT_inv]: ...
27+
def get_schema_fields(self, view: views.APIView) -> List[Any]: ...
28+
def get_schema_operation_parameters(self, view: views.APIView): ...
29+
2230
class GenericAPIView(views.APIView, UsesQuerySet[_MT_co]):
2331
queryset: Optional[Union[QuerySet[_MT_co], Manager[_MT_co]]] = ...
2432
serializer_class: Optional[Type[BaseSerializer]] = ...
2533
lookup_field: str = ...
2634
lookup_url_kwarg: Optional[str] = ...
27-
filter_backends: Sequence[Type[BaseFilterBackend]] = ...
35+
filter_backends: Sequence[Union[Type[BaseFilterBackend], Type[BaseFilterProtocol[_MT_co]]]] = ...
2836
pagination_class: Optional[Type[BasePagination]] = ...
2937
def get_object(self) -> _MT_co: ...
3038
def get_serializer(self, *args: Any, **kwargs: Any) -> BaseSerializer[_MT_co]: ...

tests/typecheck/test_filters.yml

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
- case: basic_filters
2+
main: |
3+
from rest_framework.mixins import CreateModelMixin
4+
from rest_framework.generics import GenericAPIView
5+
from rest_framework.serializers import BaseSerializer, ModelSerializer
6+
from rest_framework.filters import OrderingFilter
7+
from django.db.models import Model
8+
9+
class MyModel(Model):
10+
pass
11+
12+
class MyView(GenericAPIView):
13+
filter_backends = [OrderingFilter]
14+
15+
- case: django_filters
16+
main: |
17+
from rest_framework.mixins import CreateModelMixin
18+
from rest_framework.generics import GenericAPIView
19+
from rest_framework.views import APIView
20+
from rest_framework.serializers import BaseSerializer, ModelSerializer
21+
from django.db.models import Model
22+
23+
class MyModel(Model):
24+
pass
25+
26+
class MyFilterBackend:
27+
def filter_queryset(self, request, queryset, view):
28+
pass
29+
def get_schema_fields(self, view):
30+
pass
31+
def get_schema_operation_parameters(self, view):
32+
pass
33+
34+
class MyView(GenericAPIView):
35+
filter_backends = [MyFilterBackend]

0 commit comments

Comments
 (0)