Skip to content

Commit 00b5346

Browse files
refactor: make is_equal_to function generic and move it out of try_validate_any
refactor: add more test coverage
1 parent a5e1d3d commit 00b5346

File tree

3 files changed

+78
-11
lines changed

3 files changed

+78
-11
lines changed

src/validators/literal.rs

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -224,13 +224,10 @@ impl<T: Debug> LiteralLookup<T> {
224224
return Ok(None);
225225
};
226226

227-
let is_equal = |k: &i64| -> PyResult<bool> {
228-
let equality = py_input.call_method1("__eq__", (*k,))?;
229-
equality.extract::<bool>()
230-
};
231-
232227
if let Some(expected_ints) = &self.expected_int {
233-
let id = expected_ints.iter().find(|(k, _)| is_equal(k).unwrap_or(false));
228+
let id = expected_ints
229+
.iter()
230+
.find(|(&k, _)| is_equal_to(py_input, k).unwrap_or(false));
234231

235232
if let Some((_, id)) = id {
236233
return Ok(Some(&self.values[*id]));
@@ -241,10 +238,20 @@ impl<T: Debug> LiteralLookup<T> {
241238
return Ok(None);
242239
};
243240

241+
// try with raw strings
242+
let id = expected_strings
243+
.iter()
244+
.find(|(k, _)| is_equal_to(py_input, k.as_str()).unwrap_or(false));
245+
246+
if let Some((_, id)) = id {
247+
return Ok(Some(&self.values[*id]));
248+
}
249+
250+
// try with converting to int
244251
let id = expected_strings
245252
.iter()
246253
.filter_map(|(k, id)| k.parse::<i64>().ok().map(|k_as_int| (k_as_int, id)))
247-
.find(|(k, _)| is_equal(k).unwrap_or(false));
254+
.find(|(k, _)| is_equal_to(py_input, *k).unwrap_or(false));
248255

249256
if let Some((_, id)) = id {
250257
return Ok(Some(&self.values[*id]));
@@ -253,6 +260,11 @@ impl<T: Debug> LiteralLookup<T> {
253260
}
254261
}
255262

263+
fn is_equal_to<TValue: IntoPy<Py<PyAny>>>(input: &Bound<PyAny>, value: TValue) -> PyResult<bool> {
264+
let equality = input.call_method1("__eq__", (value,))?;
265+
equality.extract::<bool>()
266+
}
267+
256268
impl<T: PyGcTraverse + Debug> PyGcTraverse for LiteralLookup<T> {
257269
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
258270
self.expected_py_dict.py_gc_traverse(visit)?;

tests/validators/test_enums.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ class MyIntEnum(IntEnum):
384384

385385
def test_enum_int_validation_should_succeed_for_custom_type():
386386
# GIVEN
387-
class IntWrapper:
387+
class AnyWrapper:
388388
def __init__(self, value):
389389
self.value = value
390390

@@ -394,6 +394,7 @@ def __eq__(self, other: object) -> bool:
394394
class MyEnum(Enum):
395395
VALUE = 999
396396
SECOND_VALUE = 1000000
397+
THIRD_VALUE = 'Py03'
397398

398399
# WHEN
399400
v = SchemaValidator(
@@ -404,8 +405,9 @@ class MyEnum(Enum):
404405
)
405406

406407
# THEN
407-
assert v.validate_python(IntWrapper(999)) is MyEnum.VALUE
408-
assert v.validate_python(IntWrapper(1000000)) is MyEnum.SECOND_VALUE
408+
assert v.validate_python(AnyWrapper(999)) is MyEnum.VALUE
409+
assert v.validate_python(AnyWrapper(1000000)) is MyEnum.SECOND_VALUE
410+
assert v.validate_python(AnyWrapper('Py03')) is MyEnum.THIRD_VALUE
409411

410412

411413
def test_enum_str_validation_should_succeed_for_decimal_with_strict_disabled():
@@ -481,3 +483,25 @@ class MyStrEnum(Enum):
481483

482484
with pytest.raises(ValidationError):
483485
v_str.validate_python(Decimal(2.1))
486+
487+
488+
def test_enum_int_validation_should_fail_for_plain_type_without_eq_checking():
489+
# GIVEN
490+
class MyEnum(Enum):
491+
VALUE = 1
492+
493+
class MyClass:
494+
def __init__(self, value):
495+
self.value = value
496+
497+
# WHEN
498+
v = SchemaValidator(
499+
core_schema.with_default_schema(
500+
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
501+
default=MyEnum.VALUE,
502+
)
503+
)
504+
505+
# THEN
506+
with pytest.raises(ValidationError):
507+
v.validate_python(MyClass(1))

tests/validators/test_model.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import re
22
from copy import deepcopy
33
from decimal import Decimal
4-
from typing import Any, Callable, Dict, List, Set, Tuple
4+
from typing import Any, Callable, Dict, Iterable, List, Set, Tuple
55

66
import pytest
77
from dirty_equals import HasRepr, IsInstance
@@ -1379,3 +1379,34 @@ class MyModel:
13791379
v.validate_assignment(m, 'enum_field_3', IntWrappable(3))
13801380
v.validate_assignment(m, 'enum_field_4', Decimal(4))
13811381
v.validate_assignment(m, 'enum_field_4', IntWrappable(4))
1382+
1383+
1384+
def test_model_bug():
1385+
class MyModel:
1386+
__slots__ = (
1387+
'__dict__',
1388+
'__pydantic_fields_set__',
1389+
'__pydantic_extra__',
1390+
'__pydantic_private__',
1391+
)
1392+
x: Iterable[int]
1393+
1394+
# WHEN
1395+
v = SchemaValidator(
1396+
core_schema.model_schema(
1397+
MyModel,
1398+
core_schema.model_fields_schema(
1399+
{
1400+
'x': core_schema.model_field(core_schema.generator_schema()),
1401+
},
1402+
),
1403+
)
1404+
)
1405+
print(v)
1406+
1407+
# THEN
1408+
# v.validate_json('{"x": [1, 2, 3]}')
1409+
m = v.validate_python({'x': [1, 2, 3]})
1410+
print(m)
1411+
print(m.x)
1412+
print(type(m.x))

0 commit comments

Comments
 (0)