Skip to content

Commit 80308f3

Browse files
refactor: improve validation logic to include any type that is equal to a validated value
1 parent 0e23f4a commit 80308f3

File tree

7 files changed

+100
-51
lines changed

7 files changed

+100
-51
lines changed

src/input/mod.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ pub(crate) use return_enums::{
2626
EitherInt, EitherString, GenericIterator, Int, MaxLengthCheck, ValidationMatch,
2727
};
2828

29-
pub(crate) use shared::decimal_as_int;
30-
3129
// Defined here as it's not exported by pyo3
3230
pub fn py_error_on_minusone(py: Python<'_>, result: c_int) -> PyResult<()> {
3331
if result != -1 {

src/validators/decimal.rs

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use crate::errors::ErrorType;
99
use crate::errors::ValResult;
1010
use crate::errors::{ErrorTypeDefaults, Number};
1111
use crate::errors::{ToErrorValue, ValError};
12-
use crate::input::{decimal_as_int, EitherInt, Input};
12+
use crate::input::Input;
1313
use crate::tools::SchemaDict;
1414

1515
use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};
@@ -288,27 +288,3 @@ fn handle_decimal_new_error(input: impl ToErrorValue, error: PyErr, decimal_exce
288288
ValError::InternalErr(error)
289289
}
290290
}
291-
292-
pub(crate) fn try_from_decimal_to_int<'a, 'py, I: Input<'py> + ?Sized>(
293-
py: Python<'py>,
294-
input: &'a I,
295-
) -> ValResult<i64> {
296-
let Some(py_input) = input.as_python() else {
297-
return Err(ValError::new(ErrorTypeDefaults::DecimalType, input));
298-
};
299-
300-
if let Ok(false) = py_input.is_instance(get_decimal_type(py)) {
301-
return Err(ValError::new(ErrorTypeDefaults::DecimalType, input));
302-
}
303-
304-
let dec_value = match decimal_as_int(input, py_input)? {
305-
EitherInt::Py(value) => value,
306-
_ => return Err(ValError::new(ErrorType::DecimalParsing { context: None }, input)),
307-
};
308-
309-
let either_int = dec_value.exact_int()?;
310-
311-
let int = either_int.into_i64(py)?;
312-
313-
Ok(int)
314-
}

src/validators/enum_.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ impl EnumValidateValue for PlainEnumValidator {
169169
lookup: &LiteralLookup<PyObject>,
170170
strict: bool,
171171
) -> ValResult<Option<PyObject>> {
172-
match lookup.validate(py, input, strict)? {
172+
match lookup.validate(py, input)? {
173173
Some((_, v)) => Ok(Some(v.clone_ref(py))),
174174
None => {
175175
if !strict {
@@ -183,8 +183,14 @@ impl EnumValidateValue for PlainEnumValidator {
183183
} else if py_input.is_instance_of::<PyFloat>() {
184184
return Ok(lookup.validate_int(py, input, false)?.map(|v| v.clone_ref(py)));
185185
}
186+
if py_input.is_instance_of::<PyAny>() {
187+
if let Ok(Some(res)) = lookup.try_validate_any(input) {
188+
return Ok(Some(res.clone_ref(py)));
189+
}
190+
}
186191
}
187192
}
193+
188194
Ok(None)
189195
}
190196
}

src/validators/literal.rs

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ use crate::input::{Input, ValidationMatch};
1515
use crate::py_gc::PyGcTraverse;
1616
use crate::tools::SchemaDict;
1717

18-
use super::decimal::try_from_decimal_to_int;
1918
use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};
2019

2120
#[derive(Debug, Clone, Default)]
@@ -105,7 +104,6 @@ impl<T: Debug> LiteralLookup<T> {
105104
&self,
106105
py: Python<'py>,
107106
input: &'a I,
108-
strict: bool,
109107
) -> ValResult<Option<(&'a I, &T)>> {
110108
if let Some(expected_bool) = &self.expected_bool {
111109
if let Ok(bool_value) = input.validate_bool(true) {
@@ -125,13 +123,6 @@ impl<T: Debug> LiteralLookup<T> {
125123
return Ok(Some((input, &self.values[*id])));
126124
}
127125
}
128-
// if the input is a Decimal type, we need to check if its value is in the expected_ints
129-
if let Ok(value) = try_from_decimal_to_int(py, input) {
130-
let Some(id) = expected_ints.get(&value) else {
131-
return Ok(None);
132-
};
133-
return Ok(Some((input, &self.values[*id])));
134-
}
135126
}
136127

137128
if let Some(expected_strings) = &self.expected_str {
@@ -152,15 +143,6 @@ impl<T: Debug> LiteralLookup<T> {
152143
return Ok(Some((input, &self.values[*id])));
153144
}
154145
}
155-
if !strict {
156-
// if the input is a Decimal type, we need to check if its value is in the expected_ints
157-
if let Ok(value) = try_from_decimal_to_int(py, input) {
158-
let Some(id) = expected_strings.get(&value.to_string()) else {
159-
return Ok(None);
160-
};
161-
return Ok(Some((input, &self.values[*id])));
162-
}
163-
}
164146
}
165147
if let Some(expected_py_dict) = &self.expected_py_dict {
166148
// We don't use ? to unpack the result of `get_item` in the next line because unhashable
@@ -236,6 +218,38 @@ impl<T: Debug> LiteralLookup<T> {
236218
}
237219
Ok(None)
238220
}
221+
222+
pub fn try_validate_any<'a, 'py, I: Input<'py> + ?Sized>(&self, input: &'a I) -> ValResult<Option<&T>> {
223+
let Some(py_input) = input.as_python() else {
224+
return Ok(None);
225+
};
226+
227+
if let Some(expected_ints) = &self.expected_int {
228+
for (k, id) in expected_ints {
229+
if let Ok(equality) = py_input.call_method1("__eq__", (*k,)) {
230+
if equality.extract::<bool>()? {
231+
return Ok(Some(&self.values[*id]));
232+
}
233+
};
234+
}
235+
};
236+
237+
let Some(expected_strings) = &self.expected_str else {
238+
return Ok(None);
239+
};
240+
241+
for (k, id) in expected_strings {
242+
let Ok(k_as_int) = k.parse::<i64>() else {
243+
continue;
244+
};
245+
if let Ok(equality) = py_input.call_method1("__eq__", (k_as_int,)) {
246+
if equality.extract::<bool>()? {
247+
return Ok(Some(&self.values[*id]));
248+
}
249+
};
250+
}
251+
Ok(None)
252+
}
239253
}
240254

241255
impl<T: PyGcTraverse + Debug> PyGcTraverse for LiteralLookup<T> {
@@ -289,7 +303,7 @@ impl Validator for LiteralValidator {
289303
input: &(impl Input<'py> + ?Sized),
290304
_state: &mut ValidationState<'_, 'py>,
291305
) -> ValResult<PyObject> {
292-
match self.lookup.validate(py, input, _state.strict_or(false))? {
306+
match self.lookup.validate(py, input)? {
293307
Some((_, v)) => Ok(v.clone()),
294308
None => Err(ValError::new(
295309
ErrorType::LiteralError {

src/validators/union.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ impl TaggedUnionValidator {
464464
input: &(impl Input<'py> + ?Sized),
465465
state: &mut ValidationState<'_, 'py>,
466466
) -> ValResult<PyObject> {
467-
if let Ok(Some((tag, validator))) = self.lookup.validate(py, tag, state.strict_or(false)) {
467+
if let Ok(Some((tag, validator))) = self.lookup.validate(py, tag) {
468468
return match validator.validate(py, input, state) {
469469
Ok(res) => Ok(res),
470470
Err(err) => Err(err.with_outer_location(tag)),

tests/validators/test_enums.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,32 @@ class MyIntEnum(IntEnum):
382382
assert v_int.validate_python(Decimal(float(value))) is MyIntEnum.VALUE
383383

384384

385+
def test_enum_int_validation_should_succeed_for_custom_type():
386+
# GIVEN
387+
class IntWrapper:
388+
def __init__(self, value):
389+
self.value = value
390+
391+
def __eq__(self, other: object) -> bool:
392+
return self.value == other
393+
394+
class MyEnum(Enum):
395+
VALUE = 999
396+
SECOND_VALUE = 1000000
397+
398+
# WHEN
399+
v = SchemaValidator(
400+
core_schema.with_default_schema(
401+
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
402+
default=MyEnum.VALUE,
403+
)
404+
)
405+
406+
# THEN
407+
assert v.validate_python(IntWrapper(999)) is MyEnum.VALUE
408+
assert v.validate_python(IntWrapper(1000000)) is MyEnum.SECOND_VALUE
409+
410+
385411
def test_enum_str_validation_should_succeed_for_decimal_with_strict_disabled():
386412
# GIVEN
387413
class MyEnum(Enum):

tests/validators/test_model.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,12 +1315,22 @@ class OtherModel:
13151315
]
13161316

13171317

1318-
def test_model_with_enum_int_field_validation_should_succeed_for_decimal():
1318+
def test_model_with_enum_int_field_validation_should_succeed_for_any_type_equality_checks():
1319+
# GIVEN
13191320
from enum import Enum
13201321

13211322
class EnumClass(Enum):
13221323
enum_value = 1
13231324
enum_value_2 = 2
1325+
enum_value_3 = 3
1326+
enum_value_4 = '4'
1327+
1328+
class IntWrappable:
1329+
def __init__(self, value: int):
1330+
self.value = value
1331+
1332+
def __eq__(self, value: object) -> bool:
1333+
return self.value == value
13241334

13251335
class MyModel:
13261336
__slots__ = (
@@ -1331,6 +1341,7 @@ class MyModel:
13311341
)
13321342
enum_field: EnumClass
13331343

1344+
# WHEN
13341345
v = SchemaValidator(
13351346
core_schema.model_schema(
13361347
MyModel,
@@ -1342,11 +1353,29 @@ class MyModel:
13421353
'enum_field_2': core_schema.model_field(
13431354
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
13441355
),
1356+
'enum_field_3': core_schema.model_field(
1357+
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
1358+
),
1359+
'enum_field_4': core_schema.model_field(
1360+
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
1361+
),
13451362
}
13461363
),
13471364
)
13481365
)
1349-
v.validate_json('{"enum_field": 1, "enum_field_2": 2}')
1350-
m = v.validate_python({'enum_field': Decimal(1), 'enum_field_2': Decimal(2)})
1366+
1367+
# THEN
1368+
v.validate_json('{"enum_field": 1, "enum_field_2": 2, "enum_field_3": 3, "enum_field_4": "4"}')
1369+
m = v.validate_python(
1370+
{
1371+
'enum_field': Decimal(1),
1372+
'enum_field_2': Decimal(2),
1373+
'enum_field_3': IntWrappable(3),
1374+
'enum_field_4': IntWrappable(4),
1375+
}
1376+
)
13511377
v.validate_assignment(m, 'enum_field', Decimal(1))
13521378
v.validate_assignment(m, 'enum_field_2', Decimal(2))
1379+
v.validate_assignment(m, 'enum_field_3', IntWrappable(3))
1380+
v.validate_assignment(m, 'enum_field_4', Decimal(4))
1381+
v.validate_assignment(m, 'enum_field_4', IntWrappable(4))

0 commit comments

Comments
 (0)