Skip to content

Commit 6bfc32f

Browse files
committed
initial work
1 parent e0b4c94 commit 6bfc32f

File tree

3 files changed

+192
-27
lines changed

3 files changed

+192
-27
lines changed

src/validators/enum_.rs

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
// Validator for Enums, so named because "enum" is a reserved keyword in Rust.
22
use std::marker::PhantomData;
33

4-
use pyo3::exceptions::PyTypeError;
54
use pyo3::intern;
65
use pyo3::prelude::*;
76
use pyo3::types::{PyDict, PyFloat, PyInt, PyList, PyString, PyType};
87

98
use crate::build_tools::{is_strict, py_schema_err};
109
use crate::errors::{ErrorType, ValError, ValResult};
1110
use crate::input::Input;
12-
use crate::tools::{safe_repr, SchemaDict};
11+
use crate::tools::SchemaDict;
1312

1413
use super::is_instance::class_repr;
1514
use super::literal::{expected_repr_name, LiteralLookup};
@@ -119,33 +118,11 @@ impl<T: EnumValidateValue> Validator for EnumValidator<T> {
119118
} else if let Some(v) = T::validate_value(py, input, &self.lookup, strict)? {
120119
state.floor_exactness(Exactness::Lax);
121120
return Ok(v);
122-
} else if let Some(ref missing) = self.missing {
121+
} else if let Ok(res) = class.as_unbound().call1(py, (input.as_python(),)) {
123122
state.floor_exactness(Exactness::Lax);
124-
let enum_value = missing.bind(py).call1((input.to_object(py),)).map_err(|_| {
125-
ValError::new(
126-
ErrorType::Enum {
127-
expected: self.expected_repr.clone(),
128-
context: None,
129-
},
130-
input,
131-
)
132-
})?;
133-
// check enum_value is an instance of the class like
134-
// https://github.com/python/cpython/blob/v3.12.2/Lib/enum.py#L1148
135-
if enum_value.is_instance(class)? {
136-
return Ok(enum_value.into());
137-
} else if !enum_value.is(&py.None()) {
138-
let type_error = PyTypeError::new_err(format!(
139-
"error in {}._missing_: returned {} instead of None or a valid member",
140-
class
141-
.name()
142-
.and_then(|name| name.extract::<String>())
143-
.unwrap_or_else(|_| "<Unknown>".into()),
144-
safe_repr(&enum_value)
145-
));
146-
return Err(type_error.into());
147-
}
123+
return Ok(res);
148124
}
125+
149126
Err(ValError::new(
150127
ErrorType::Enum {
151128
expected: self.expected_repr.clone(),

tests/validators/test_enums.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import re
22
import sys
3+
from decimal import Decimal
34
from enum import Enum, IntEnum, IntFlag
45

56
import pytest
@@ -344,3 +345,130 @@ class ColorEnum(IntEnum):
344345

345346
assert v.validate_python(ColorEnum.GREEN) is ColorEnum.GREEN
346347
assert v.validate_python(1 << 63) is ColorEnum.GREEN
348+
349+
350+
@pytest.mark.parametrize(
351+
'value',
352+
[-1, 0, 1],
353+
)
354+
def test_enum_int_validation_should_succeed_for_decimal(value: int):
355+
# GIVEN
356+
class MyEnum(Enum):
357+
VALUE = value
358+
359+
class MyIntEnum(IntEnum):
360+
VALUE = value
361+
362+
# WHEN
363+
v = SchemaValidator(
364+
core_schema.with_default_schema(
365+
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
366+
default=MyEnum.VALUE,
367+
)
368+
)
369+
370+
v_int = SchemaValidator(
371+
core_schema.with_default_schema(
372+
schema=core_schema.enum_schema(MyIntEnum, list(MyIntEnum.__members__.values())),
373+
default=MyIntEnum.VALUE,
374+
)
375+
)
376+
377+
# THEN
378+
assert v.validate_python(Decimal(value)) is MyEnum.VALUE
379+
assert v.validate_python(Decimal(float(value))) is MyEnum.VALUE
380+
381+
assert v_int.validate_python(Decimal(value)) is MyIntEnum.VALUE
382+
assert v_int.validate_python(Decimal(float(value))) is MyIntEnum.VALUE
383+
384+
385+
def test_enum_int_validation_should_succeed_for_custom_type():
386+
# GIVEN
387+
class AnyWrapper:
388+
def __init__(self, value):
389+
self.value = value
390+
391+
def __eq__(self, other: object) -> bool:
392+
return self.value == other
393+
394+
class MyEnum(Enum):
395+
VALUE = 999
396+
SECOND_VALUE = 1000000
397+
THIRD_VALUE = 'Py03'
398+
399+
# WHEN
400+
v = SchemaValidator(
401+
core_schema.with_default_schema(
402+
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
403+
default=MyEnum.VALUE,
404+
)
405+
)
406+
407+
# THEN
408+
assert v.validate_python(AnyWrapper(999)) is MyEnum.VALUE
409+
assert v.validate_python(AnyWrapper(1000000)) is MyEnum.SECOND_VALUE
410+
assert v.validate_python(AnyWrapper('Py03')) is MyEnum.THIRD_VALUE
411+
412+
413+
def test_enum_str_validation_should_fail_for_decimal_when_expecting_str_value():
414+
# GIVEN
415+
class MyEnum(Enum):
416+
VALUE = '1'
417+
418+
# WHEN
419+
v = SchemaValidator(
420+
core_schema.with_default_schema(
421+
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
422+
default=MyEnum.VALUE,
423+
)
424+
)
425+
426+
# THEN
427+
with pytest.raises(ValidationError):
428+
v.validate_python(Decimal(1))
429+
430+
431+
def test_enum_int_validation_should_fail_for_incorrect_decimal_value():
432+
# GIVEN
433+
class MyEnum(Enum):
434+
VALUE = 1
435+
436+
# WHEN
437+
v = SchemaValidator(
438+
core_schema.with_default_schema(
439+
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
440+
default=MyEnum.VALUE,
441+
)
442+
)
443+
444+
# THEN
445+
with pytest.raises(ValidationError):
446+
v.validate_python(Decimal(2))
447+
448+
with pytest.raises(ValidationError):
449+
v.validate_python((1, 2))
450+
451+
with pytest.raises(ValidationError):
452+
v.validate_python(Decimal(1.1))
453+
454+
455+
def test_enum_int_validation_should_fail_for_plain_type_without_eq_checking():
456+
# GIVEN
457+
class MyEnum(Enum):
458+
VALUE = 1
459+
460+
class MyClass:
461+
def __init__(self, value):
462+
self.value = value
463+
464+
# WHEN
465+
v = SchemaValidator(
466+
core_schema.with_default_schema(
467+
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
468+
default=MyEnum.VALUE,
469+
)
470+
)
471+
472+
# THEN
473+
with pytest.raises(ValidationError):
474+
v.validate_python(MyClass(1))

tests/validators/test_model.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import re
22
from copy import deepcopy
3+
from decimal import Decimal
34
from typing import Any, Callable, Dict, List, Set, Tuple
45

56
import pytest
@@ -1312,3 +1313,62 @@ class OtherModel:
13121313
'ctx': {'class_name': 'MyModel'},
13131314
}
13141315
]
1316+
1317+
1318+
def test_model_with_enum_int_field_validation_should_succeed_for_any_type_equality_checks():
1319+
# GIVEN
1320+
from enum import Enum
1321+
1322+
class EnumClass(Enum):
1323+
enum_value = 1
1324+
enum_value_2 = 2
1325+
enum_value_3 = 3
1326+
1327+
class IntWrappable:
1328+
def __init__(self, value: int):
1329+
self.value = value
1330+
1331+
def __eq__(self, value: object) -> bool:
1332+
return self.value == value
1333+
1334+
class MyModel:
1335+
__slots__ = (
1336+
'__dict__',
1337+
'__pydantic_fields_set__',
1338+
'__pydantic_extra__',
1339+
'__pydantic_private__',
1340+
)
1341+
enum_field: EnumClass
1342+
1343+
# WHEN
1344+
v = SchemaValidator(
1345+
core_schema.model_schema(
1346+
MyModel,
1347+
core_schema.model_fields_schema(
1348+
{
1349+
'enum_field': core_schema.model_field(
1350+
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
1351+
),
1352+
'enum_field_2': core_schema.model_field(
1353+
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
1354+
),
1355+
'enum_field_3': core_schema.model_field(
1356+
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
1357+
),
1358+
}
1359+
),
1360+
)
1361+
)
1362+
1363+
# THEN
1364+
v.validate_json('{"enum_field": 1, "enum_field_2": 2, "enum_field_3": 3}')
1365+
m = v.validate_python(
1366+
{
1367+
'enum_field': Decimal(1),
1368+
'enum_field_2': Decimal(2),
1369+
'enum_field_3': IntWrappable(3),
1370+
}
1371+
)
1372+
v.validate_assignment(m, 'enum_field', Decimal(1))
1373+
v.validate_assignment(m, 'enum_field_2', Decimal(2))
1374+
v.validate_assignment(m, 'enum_field_3', IntWrappable(3))

0 commit comments

Comments
 (0)