Skip to content

Commit 8c1a0da

Browse files
Support wider variety of enum validation cases (#1456)
1 parent e0b4c94 commit 8c1a0da

File tree

3 files changed

+216
-3
lines changed

3 files changed

+216
-3
lines changed

src/validators/enum_.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,15 @@ impl<T: EnumValidateValue> Validator for EnumValidator<T> {
116116
},
117117
input,
118118
));
119-
} else if let Some(v) = T::validate_value(py, input, &self.lookup, strict)? {
120-
state.floor_exactness(Exactness::Lax);
119+
}
120+
121+
state.floor_exactness(Exactness::Lax);
122+
123+
if let Some(v) = T::validate_value(py, input, &self.lookup, strict)? {
121124
return Ok(v);
125+
} else if let Ok(res) = class.as_unbound().call1(py, (input.as_python(),)) {
126+
return Ok(res);
122127
} else if let Some(ref missing) = self.missing {
123-
state.floor_exactness(Exactness::Lax);
124128
let enum_value = missing.bind(py).call1((input.to_object(py),)).map_err(|_| {
125129
ValError::new(
126130
ErrorType::Enum {
@@ -146,6 +150,7 @@ impl<T: EnumValidateValue> Validator for EnumValidator<T> {
146150
return Err(type_error.into());
147151
}
148152
}
153+
149154
Err(ValError::new(
150155
ErrorType::Enum {
151156
expected: self.expected_repr.clone(),

tests/validators/test_enums.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import re
22
import sys
3+
from decimal import Decimal
34
from enum import Enum, IntEnum, IntFlag
45

56
import pytest
@@ -344,3 +345,145 @@ class ColorEnum(IntEnum):
344345

345346
assert v.validate_python(ColorEnum.GREEN) is ColorEnum.GREEN
346347
assert v.validate_python(1 << 63) is ColorEnum.GREEN
348+
349+
350+
@pytest.mark.parametrize(
351+
'value',
352+
[-1, 0, 1],
353+
)
354+
def test_enum_int_validation_should_succeed_for_decimal(value: int):
355+
class MyEnum(Enum):
356+
VALUE = value
357+
358+
class MyIntEnum(IntEnum):
359+
VALUE = value
360+
361+
v = SchemaValidator(
362+
core_schema.with_default_schema(
363+
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
364+
default=MyEnum.VALUE,
365+
)
366+
)
367+
368+
v_int = SchemaValidator(
369+
core_schema.with_default_schema(
370+
schema=core_schema.enum_schema(MyIntEnum, list(MyIntEnum.__members__.values())),
371+
default=MyIntEnum.VALUE,
372+
)
373+
)
374+
375+
assert v.validate_python(Decimal(value)) is MyEnum.VALUE
376+
assert v.validate_python(Decimal(float(value))) is MyEnum.VALUE
377+
assert v_int.validate_python(Decimal(value)) is MyIntEnum.VALUE
378+
assert v_int.validate_python(Decimal(float(value))) is MyIntEnum.VALUE
379+
380+
381+
@pytest.mark.skipif(
382+
sys.version_info >= (3, 13),
383+
reason='Python 3.13+ enum initialization is different, see https://github.com/python/cpython/blob/ec610069637d56101896803a70d418a89afe0b4b/Lib/enum.py#L1159-L1163',
384+
)
385+
def test_enum_int_validation_should_succeed_for_custom_type():
386+
class AnyWrapper:
387+
def __init__(self, value):
388+
self.value = value
389+
390+
def __eq__(self, other: object) -> bool:
391+
return self.value == other
392+
393+
class MyEnum(Enum):
394+
VALUE = 999
395+
SECOND_VALUE = 1000000
396+
THIRD_VALUE = 'Py03'
397+
398+
v = SchemaValidator(
399+
core_schema.with_default_schema(
400+
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
401+
default=MyEnum.VALUE,
402+
)
403+
)
404+
405+
assert v.validate_python(AnyWrapper(999)) is MyEnum.VALUE
406+
assert v.validate_python(AnyWrapper(1000000)) is MyEnum.SECOND_VALUE
407+
assert v.validate_python(AnyWrapper('Py03')) is MyEnum.THIRD_VALUE
408+
409+
410+
def test_enum_str_validation_should_fail_for_decimal_when_expecting_str_value():
411+
class MyEnum(Enum):
412+
VALUE = '1'
413+
414+
v = SchemaValidator(
415+
core_schema.with_default_schema(
416+
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
417+
default=MyEnum.VALUE,
418+
)
419+
)
420+
421+
with pytest.raises(ValidationError):
422+
v.validate_python(Decimal(1))
423+
424+
425+
def test_enum_int_validation_should_fail_for_incorrect_decimal_value():
426+
class MyEnum(Enum):
427+
VALUE = 1
428+
429+
v = SchemaValidator(
430+
core_schema.with_default_schema(
431+
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
432+
default=MyEnum.VALUE,
433+
)
434+
)
435+
436+
with pytest.raises(ValidationError):
437+
v.validate_python(Decimal(2))
438+
439+
with pytest.raises(ValidationError):
440+
v.validate_python((1, 2))
441+
442+
with pytest.raises(ValidationError):
443+
v.validate_python(Decimal(1.1))
444+
445+
446+
def test_enum_int_validation_should_fail_for_plain_type_without_eq_checking():
447+
class MyEnum(Enum):
448+
VALUE = 1
449+
450+
class MyClass:
451+
def __init__(self, value):
452+
self.value = value
453+
454+
v = SchemaValidator(
455+
core_schema.with_default_schema(
456+
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
457+
default=MyEnum.VALUE,
458+
)
459+
)
460+
461+
with pytest.raises(ValidationError):
462+
v.validate_python(MyClass(1))
463+
464+
465+
def support_custom_new_method() -> None:
466+
"""Demonstrates support for custom new methods, as well as conceptually, multi-value enums without dependency on a 3rd party lib for testing."""
467+
468+
class Animal(Enum):
469+
CAT = 'cat', 'meow'
470+
DOG = 'dog', 'woof'
471+
472+
def __new__(cls, species: str, sound: str):
473+
obj = object.__new__(cls)
474+
475+
obj._value_ = species
476+
obj._all_values = (species, sound)
477+
478+
obj.species = species
479+
obj.sound = sound
480+
481+
cls._value2member_map_[sound] = obj
482+
483+
return obj
484+
485+
v = SchemaValidator(core_schema.enum_schema(Animal, list(Animal.__members__.values())))
486+
assert v.validate_python('cat') is Animal.CAT
487+
assert v.validate_python('meow') is Animal.CAT
488+
assert v.validate_python('dog') is Animal.DOG
489+
assert v.validate_python('woof') is Animal.DOG

tests/validators/test_model.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import re
2+
import sys
23
from copy import deepcopy
4+
from decimal import Decimal
35
from typing import Any, Callable, Dict, List, Set, Tuple
46

57
import pytest
@@ -1312,3 +1314,66 @@ class OtherModel:
13121314
'ctx': {'class_name': 'MyModel'},
13131315
}
13141316
]
1317+
1318+
1319+
@pytest.mark.skipif(
1320+
sys.version_info >= (3, 13),
1321+
reason='Python 3.13+ enum initialization is different, see https://github.com/python/cpython/blob/ec610069637d56101896803a70d418a89afe0b4b/Lib/enum.py#L1159-L1163',
1322+
)
1323+
def test_model_with_enum_int_field_validation_should_succeed_for_any_type_equality_checks():
1324+
# GIVEN
1325+
from enum import Enum
1326+
1327+
class EnumClass(Enum):
1328+
enum_value = 1
1329+
enum_value_2 = 2
1330+
enum_value_3 = 3
1331+
1332+
class IntWrappable:
1333+
def __init__(self, value: int):
1334+
self.value = value
1335+
1336+
def __eq__(self, other: object) -> bool:
1337+
return self.value == other
1338+
1339+
class MyModel:
1340+
__slots__ = (
1341+
'__dict__',
1342+
'__pydantic_fields_set__',
1343+
'__pydantic_extra__',
1344+
'__pydantic_private__',
1345+
)
1346+
enum_field: EnumClass
1347+
1348+
# WHEN
1349+
v = SchemaValidator(
1350+
core_schema.model_schema(
1351+
MyModel,
1352+
core_schema.model_fields_schema(
1353+
{
1354+
'enum_field': core_schema.model_field(
1355+
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
1356+
),
1357+
'enum_field_2': core_schema.model_field(
1358+
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
1359+
),
1360+
'enum_field_3': core_schema.model_field(
1361+
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
1362+
),
1363+
}
1364+
),
1365+
)
1366+
)
1367+
1368+
# THEN
1369+
v.validate_json('{"enum_field": 1, "enum_field_2": 2, "enum_field_3": 3}')
1370+
m = v.validate_python(
1371+
{
1372+
'enum_field': Decimal(1),
1373+
'enum_field_2': Decimal(2),
1374+
'enum_field_3': IntWrappable(3),
1375+
}
1376+
)
1377+
v.validate_assignment(m, 'enum_field', Decimal(1))
1378+
v.validate_assignment(m, 'enum_field_2', Decimal(2))
1379+
v.validate_assignment(m, 'enum_field_3', IntWrappable(3))

0 commit comments

Comments
 (0)