Skip to content

Commit 5d6986c

Browse files
fix: refactoring based on code review comment
removed try_validate_any function and instead try to create Python enum class. Test case modifications and fixes.
1 parent a6c2a30 commit 5d6986c

File tree

6 files changed

+11
-91
lines changed

6 files changed

+11
-91
lines changed

src/validators/enum_.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ pub trait EnumValidateValue: std::fmt::Debug + Clone + Send + Sync {
7979
py: Python<'py>,
8080
input: &I,
8181
lookup: &LiteralLookup<PyObject>,
82+
class: &Py<PyType>,
8283
strict: bool,
8384
) -> ValResult<Option<PyObject>>;
8485
}
@@ -116,7 +117,7 @@ impl<T: EnumValidateValue> Validator for EnumValidator<T> {
116117
},
117118
input,
118119
));
119-
} else if let Some(v) = T::validate_value(py, input, &self.lookup, strict)? {
120+
} else if let Some(v) = T::validate_value(py, input, &self.lookup, &self.class, strict)? {
120121
state.floor_exactness(Exactness::Lax);
121122
return Ok(v);
122123
} else if let Some(ref missing) = self.missing {
@@ -167,6 +168,7 @@ impl EnumValidateValue for PlainEnumValidator {
167168
py: Python<'py>,
168169
input: &I,
169170
lookup: &LiteralLookup<PyObject>,
171+
class: &Py<PyType>,
170172
strict: bool,
171173
) -> ValResult<Option<PyObject>> {
172174
match lookup.validate(py, input)? {
@@ -184,8 +186,8 @@ impl EnumValidateValue for PlainEnumValidator {
184186
return Ok(lookup.validate_int(py, input, false)?.map(|v| v.clone_ref(py)));
185187
}
186188
if py_input.is_instance_of::<PyAny>() {
187-
if let Ok(Some(res)) = lookup.try_validate_any(input) {
188-
return Ok(Some(res.clone_ref(py)));
189+
if let Ok(res) = class.call1(py, (py_input,)) {
190+
return Ok(Some(res));
189191
}
190192
}
191193
}
@@ -207,6 +209,7 @@ impl EnumValidateValue for IntEnumValidator {
207209
py: Python<'py>,
208210
input: &I,
209211
lookup: &LiteralLookup<PyObject>,
212+
_class: &Py<PyType>,
210213
strict: bool,
211214
) -> ValResult<Option<PyObject>> {
212215
Ok(lookup.validate_int(py, input, strict)?.map(|v| v.clone_ref(py)))
@@ -223,6 +226,7 @@ impl EnumValidateValue for StrEnumValidator {
223226
py: Python,
224227
input: &I,
225228
lookup: &LiteralLookup<PyObject>,
229+
_class: &Py<PyType>,
226230
strict: bool,
227231
) -> ValResult<Option<PyObject>> {
228232
Ok(lookup.validate_str(input, strict)?.map(|v| v.clone_ref(py)))
@@ -239,6 +243,7 @@ impl EnumValidateValue for FloatEnumValidator {
239243
py: Python<'py>,
240244
input: &I,
241245
lookup: &LiteralLookup<PyObject>,
246+
_class: &Py<PyType>,
242247
strict: bool,
243248
) -> ValResult<Option<PyObject>> {
244249
Ok(lookup.validate_float(py, input, strict)?.map(|v| v.clone_ref(py)))

src/validators/generator.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ impl Validator for GeneratorValidator {
8686
hide_input_in_errors: self.hide_input_in_errors,
8787
validation_error_cause: self.validation_error_cause,
8888
};
89+
8990
Ok(v_iterator.into_py(py))
9091
}
9192

src/validators/literal.rs

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -218,51 +218,6 @@ impl<T: Debug> LiteralLookup<T> {
218218
}
219219
Ok(None)
220220
}
221-
222-
pub fn try_validate_any<'a, 'py, I: Input<'py> + ?Sized>(&self, input: &'a I) -> ValResult<Option<&T>> {
223-
let Some(py_input) = input.as_python() else {
224-
return Ok(None);
225-
};
226-
227-
if let Some(expected_ints) = &self.expected_int {
228-
let id = expected_ints
229-
.iter()
230-
.find(|(&k, _)| is_equal_to(py_input, k).unwrap_or(false));
231-
232-
if let Some((_, id)) = id {
233-
return Ok(Some(&self.values[*id]));
234-
}
235-
};
236-
237-
let Some(expected_strings) = &self.expected_str else {
238-
return Ok(None);
239-
};
240-
241-
// try with raw strings
242-
let id = expected_strings
243-
.iter()
244-
.find(|(k, _)| is_equal_to(py_input, k.as_str()).unwrap_or(false));
245-
246-
if let Some((_, id)) = id {
247-
return Ok(Some(&self.values[*id]));
248-
}
249-
250-
// try with converting to int
251-
let id = expected_strings
252-
.iter()
253-
.filter_map(|(k, id)| k.parse::<i64>().ok().map(|k_as_int| (k_as_int, id)))
254-
.find(|(k, _)| is_equal_to(py_input, *k).unwrap_or(false));
255-
256-
if let Some((_, id)) = id {
257-
return Ok(Some(&self.values[*id]));
258-
}
259-
Ok(None)
260-
}
261-
}
262-
263-
fn is_equal_to<TValue: IntoPy<Py<PyAny>>>(input: &Bound<PyAny>, value: TValue) -> PyResult<bool> {
264-
let equality = input.call_method1("__eq__", (value,))?;
265-
equality.extract::<bool>()
266221
}
267222

268223
impl<T: PyGcTraverse + Debug> PyGcTraverse for LiteralLookup<T> {

src/validators/model.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,6 @@ impl ModelValidator {
270270
.map_err(|e| convert_err(py, e, input));
271271
}
272272
}
273-
274273
let output = self.validator.validate(py, input, state)?;
275274

276275
let instance = create_class(self.class.bind(py))?;

tests/validators/test_enums.py

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ class MyEnum(Enum):
410410
assert v.validate_python(AnyWrapper('Py03')) is MyEnum.THIRD_VALUE
411411

412412

413-
def test_enum_str_validation_should_succeed_for_decimal_with_strict_disabled():
413+
def test_enum_str_validation_should_fail_for_decimal_when_expecting_str_value():
414414
# GIVEN
415415
class MyEnum(Enum):
416416
VALUE = '1'
@@ -423,23 +423,6 @@ class MyEnum(Enum):
423423
)
424424
)
425425

426-
# THEN
427-
assert v.validate_python(Decimal(1)) is MyEnum.VALUE
428-
429-
430-
def test_enum_str_validation_should_fail_for_decimal_with_strict_enabled():
431-
# GIVEN
432-
class MyEnum(Enum):
433-
VALUE = '1'
434-
435-
# WHEN
436-
v = SchemaValidator(
437-
core_schema.with_default_schema(
438-
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values()), strict=True),
439-
default=MyEnum.VALUE,
440-
)
441-
)
442-
443426
# THEN
444427
with pytest.raises(ValidationError):
445428
v.validate_python(Decimal(1))
@@ -450,9 +433,6 @@ def test_enum_int_validation_should_fail_for_incorrect_decimal_value():
450433
class MyEnum(Enum):
451434
VALUE = 1
452435

453-
class MyStrEnum(Enum):
454-
VALUE = '2'
455-
456436
# WHEN
457437
v = SchemaValidator(
458438
core_schema.with_default_schema(
@@ -461,13 +441,6 @@ class MyStrEnum(Enum):
461441
)
462442
)
463443

464-
v_str = SchemaValidator(
465-
core_schema.with_default_schema(
466-
schema=core_schema.enum_schema(MyStrEnum, list(MyStrEnum.__members__.values())),
467-
default=MyStrEnum.VALUE,
468-
)
469-
)
470-
471444
# THEN
472445
with pytest.raises(ValidationError):
473446
v.validate_python(Decimal(2))
@@ -478,12 +451,6 @@ class MyStrEnum(Enum):
478451
with pytest.raises(ValidationError):
479452
v.validate_python(Decimal(1.1))
480453

481-
with pytest.raises(ValidationError):
482-
v_str.validate_python(Decimal(1))
483-
484-
with pytest.raises(ValidationError):
485-
v_str.validate_python(Decimal(2.1))
486-
487454

488455
def test_enum_int_validation_should_fail_for_plain_type_without_eq_checking():
489456
# GIVEN

tests/validators/test_model.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,7 +1323,6 @@ class EnumClass(Enum):
13231323
enum_value = 1
13241324
enum_value_2 = 2
13251325
enum_value_3 = 3
1326-
enum_value_4 = '4'
13271326

13281327
class IntWrappable:
13291328
def __init__(self, value: int):
@@ -1356,26 +1355,20 @@ class MyModel:
13561355
'enum_field_3': core_schema.model_field(
13571356
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
13581357
),
1359-
'enum_field_4': core_schema.model_field(
1360-
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
1361-
),
13621358
}
13631359
),
13641360
)
13651361
)
13661362

13671363
# THEN
1368-
v.validate_json('{"enum_field": 1, "enum_field_2": 2, "enum_field_3": 3, "enum_field_4": "4"}')
1364+
v.validate_json('{"enum_field": 1, "enum_field_2": 2, "enum_field_3": 3}')
13691365
m = v.validate_python(
13701366
{
13711367
'enum_field': Decimal(1),
13721368
'enum_field_2': Decimal(2),
13731369
'enum_field_3': IntWrappable(3),
1374-
'enum_field_4': IntWrappable(4),
13751370
}
13761371
)
13771372
v.validate_assignment(m, 'enum_field', Decimal(1))
13781373
v.validate_assignment(m, 'enum_field_2', Decimal(2))
13791374
v.validate_assignment(m, 'enum_field_3', IntWrappable(3))
1380-
v.validate_assignment(m, 'enum_field_4', Decimal(4))
1381-
v.validate_assignment(m, 'enum_field_4', IntWrappable(4))

0 commit comments

Comments
 (0)