Skip to content

Commit 7cbc640

Browse files
committed
fix #113
1 parent d3ec5a4 commit 7cbc640

File tree

8 files changed

+316
-236
lines changed

8 files changed

+316
-236
lines changed

src/django_enum/drf.py

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""Support for django rest framework symmetric serialization"""
22

3-
__all__ = ["EnumField", "EnumFieldMixin"]
3+
__all__ = ["EnumField", "FlagField", "EnumFieldMixin"]
44

55
import inspect
66
from datetime import date, datetime, time, timedelta
77
from decimal import Decimal, DecimalException
8-
from enum import Enum
8+
from enum import Enum, Flag
9+
from functools import reduce
10+
from operator import or_
911
from typing import Any, Dict, Optional, Type, Union
1012

1113
from rest_framework.fields import (
@@ -23,7 +25,8 @@
2325
from rest_framework.serializers import ModelSerializer
2426
from rest_framework.utils.field_mapping import get_field_kwargs
2527

26-
from django_enum import EnumField as EnumModelField
28+
from django_enum.fields import EnumField as EnumModelField
29+
from django_enum.fields import FlagField as FlagModelField
2730
from django_enum.utils import (
2831
choices,
2932
decimal_params,
@@ -172,6 +175,80 @@ def to_representation(self, value: Any) -> Any:
172175
return getattr(value, "value", value)
173176

174177

178+
class FlagField(ChoiceField):
179+
"""
180+
A djangorestframework serializer field for :class:`~enum.Flag` types. If
181+
unspecified ModelSerializers will assign :class:`~django_enum.fields.FlagField`
182+
model field types to `ChoiceField
183+
<https://www.django-rest-framework.org/api-guide/fields/#choicefield>`_ which will
184+
not combine composite flag values appropriately. This field will also allow any
185+
symmetric values to be used (e.g. labels or names instead of values).
186+
187+
**You should add** :class:`~django_enum.drf.EnumFieldMixin` **to your serializer to
188+
automatically use this field.**
189+
190+
:param enum: The type of the flag of the field
191+
:param strict: If True (default) only values in the flag type
192+
will be acceptable. If False, no errors will be thrown if other
193+
values of the same primitive type are used
194+
:param kwargs: Any other named arguments applicable to a ChoiceField
195+
will be passed up to the base classes.
196+
"""
197+
198+
enum: Type[Flag]
199+
strict: bool = True
200+
201+
def __init__(self, enum: Type[Flag], strict: bool = strict, **kwargs):
202+
self.enum = enum
203+
self.strict = strict
204+
self.choices = kwargs.pop("choices", choices(enum))
205+
kwargs.pop("field_name", None)
206+
kwargs.pop("model_field", None)
207+
super().__init__(choices=self.choices, **kwargs)
208+
209+
def to_internal_value(self, data: Any) -> Union[Enum, Any]:
210+
"""
211+
Transform the *incoming* primitive data into an enum instance.
212+
We accept a composite flag value or a list of values. If a list,
213+
each element will be converted to a flag value and then the values
214+
will be reduced into a composite value with the or operator.
215+
216+
:return: A composite flag value.
217+
"""
218+
if not data:
219+
if self.allow_null and (data is None or data == ""):
220+
return None
221+
return self.enum(0)
222+
223+
if not isinstance(data, self.enum):
224+
try:
225+
return self.enum(data)
226+
except (TypeError, ValueError):
227+
try:
228+
if isinstance(data, str):
229+
return self.enum[data]
230+
if isinstance(data, (list, tuple)):
231+
values = []
232+
for val in data:
233+
try:
234+
values.append(self.enum(val))
235+
except (TypeError, ValueError):
236+
values.append(self.enum[val])
237+
return reduce(or_, values)
238+
except (TypeError, ValueError, KeyError):
239+
pass
240+
self.fail("invalid_choice", input=data)
241+
return data
242+
243+
def to_representation(self, value: Any) -> Any:
244+
"""
245+
Transform the *outgoing* enum value into its primitive value.
246+
247+
:return: The primitive composite value of the flag (most likely an integer).
248+
"""
249+
return getattr(value, "value", value)
250+
251+
175252
class EnumFieldMixin(with_typehint(ModelSerializer)): # type: ignore
176253
"""
177254
A mixin for ModelSerializers that adds auto-magic support for
@@ -204,7 +281,9 @@ class Meta:
204281
:return: A 2-tuple, the first element is the field class, the
205282
second is the kwargs for the field
206283
"""
207-
field_class = ClassLookupDict({EnumModelField: EnumField})[model_field]
284+
field_class = ClassLookupDict(
285+
{FlagModelField: FlagField, EnumModelField: EnumField}
286+
)[model_field]
208287
if field_class:
209288
return field_class, {
210289
"enum": model_field.enum,

src/django_enum/forms.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
"NonStrictSelect",
3535
"NonStrictSelectMultiple",
3636
"FlagSelectMultiple",
37+
"FlagCheckbox",
38+
"NonStrictFlagSelectMultiple",
39+
"NonStrictFlagCheckbox",
3740
"NonStrictRadioSelect",
3841
"ChoiceFieldMixin",
3942
"EnumChoiceField",

tests/djenum/urls.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,11 @@
3434
try:
3535
from rest_framework import routers
3636

37-
from tests.djenum.views import DRFView
37+
from tests.djenum.views import DRFView, DRFFlagView
3838

3939
router = routers.DefaultRouter()
4040
router.register(r"enumtesters", DRFView)
41+
router.register(r"flagtesters", DRFFlagView)
4142
urlpatterns.append(path("drf/", include(router.urls)))
4243

4344
except ImportError: # pragma: no cover

tests/djenum/views.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,16 @@ class DRFView(viewsets.ModelViewSet):
143143
queryset = EnumTester.objects.all()
144144
serializer_class = EnumTesterSerializer
145145

146+
class FlagTesterSerializer(EnumFieldMixin, serializers.ModelSerializer):
147+
class Meta:
148+
model = FlagFilterTester
149+
fields = "__all__"
150+
151+
class DRFFlagView(viewsets.ModelViewSet):
152+
queryset = FlagFilterTester.objects.all()
153+
serializer_class = FlagTesterSerializer
154+
155+
146156
except (ImportError, ModuleNotFoundError): # pragma: no cover
147157
pass
148158

tests/enum_prop/urls.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@
3232
try:
3333
from rest_framework import routers
3434

35-
from tests.enum_prop.views import DRFView
35+
from tests.enum_prop.views import DRFView, DRFFlagView
3636

3737
router = routers.DefaultRouter()
3838
router.register(r"enumtesters", DRFView)
39+
router.register(r"flagtesters", DRFFlagView)
3940
urlpatterns.append(path("drf/", include(router.urls)))
4041

4142
except ImportError: # pragma: no cover

tests/enum_prop/views.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,15 @@ class DRFView(viewsets.ModelViewSet):
124124
queryset = EnumTester.objects.all()
125125
serializer_class = EnumTesterSerializer
126126

127+
class FlagTesterSerializer(EnumFieldMixin, serializers.ModelSerializer):
128+
class Meta:
129+
model = FlagFilterTester
130+
fields = "__all__"
131+
132+
class DRFFlagView(viewsets.ModelViewSet):
133+
queryset = FlagFilterTester.objects.all()
134+
serializer_class = FlagTesterSerializer
135+
127136
except (ImportError, ModuleNotFoundError): # pragma: no cover
128137
pass
129138

0 commit comments

Comments
 (0)