Skip to content

Commit b11d3e8

Browse files
fix: make Decimal type work with StrEnum when strict mode is not enabled
1 parent b327afa commit b11d3e8

File tree

4 files changed

+116
-13
lines changed

4 files changed

+116
-13
lines changed

src/validators/enum_.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ impl EnumValidateValue for PlainEnumValidator {
169169
lookup: &LiteralLookup<PyObject>,
170170
strict: bool,
171171
) -> ValResult<Option<PyObject>> {
172-
match lookup.validate(py, input)? {
172+
match lookup.validate(py, input, strict)? {
173173
Some((_, v)) => Ok(Some(v.clone_ref(py))),
174174
None => {
175175
if !strict {

src/validators/literal.rs

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ impl<T: Debug> LiteralLookup<T> {
105105
&self,
106106
py: Python<'py>,
107107
input: &'a I,
108+
strict: bool,
108109
) -> ValResult<Option<(&'a I, &T)>> {
109110
if let Some(expected_bool) = &self.expected_bool {
110111
if let Ok(bool_value) = input.validate_bool(true) {
@@ -124,6 +125,10 @@ impl<T: Debug> LiteralLookup<T> {
124125
return Ok(Some((input, &self.values[*id])));
125126
}
126127
}
128+
// if the input is a Decimal type, we need to check if its value is in the expected_ints
129+
if let Ok(Some(v)) = self.try_from_dec_to_int(py, input, expected_ints) {
130+
return Ok(Some(v));
131+
}
127132
}
128133

129134
if let Some(expected_strings) = &self.expected_str {
@@ -144,6 +149,12 @@ impl<T: Debug> LiteralLookup<T> {
144149
return Ok(Some((input, &self.values[*id])));
145150
}
146151
}
152+
if !strict {
153+
// if the input is a Decimal type, we need to check if its value is in the expected_strings
154+
if let Ok(Some(v)) = self.try_from_dec_to_str(py, input, expected_strings) {
155+
return Ok(Some(v));
156+
}
157+
}
147158
}
148159
if let Some(expected_py_dict) = &self.expected_py_dict {
149160
// We don't use ? to unpack the result of `get_item` in the next line because unhashable
@@ -166,17 +177,14 @@ impl<T: Debug> LiteralLookup<T> {
166177
}
167178
};
168179

169-
// if the input is a Decimal type, we need to check if its value is in the expected_ints
170-
if let Ok(Some(v)) = self.validate_decimal(py, input) {
171-
return Ok(Some(v));
172-
}
173180
Ok(None)
174181
}
175182

176-
fn validate_decimal<'a, 'py, I: Input<'py> + ?Sized>(
183+
fn try_from_dec_to_int<'a, 'py, I: Input<'py> + ?Sized>(
177184
&self,
178185
py: Python<'py>,
179186
input: &'a I,
187+
expected_ints: &AHashMap<i64, usize>,
180188
) -> ValResult<Option<(&'a I, &T)>> {
181189
let Some(py_input) = input.as_python() else {
182190
return Ok(None);
@@ -186,10 +194,6 @@ impl<T: Debug> LiteralLookup<T> {
186194
return Ok(None);
187195
}
188196

189-
let Some(expected_ints) = &self.expected_int else {
190-
return Ok(None);
191-
};
192-
193197
let Ok(EitherInt::Py(dec_value)) = decimal_as_int(input, py_input) else {
194198
return Ok(None);
195199
};
@@ -202,9 +206,39 @@ impl<T: Debug> LiteralLookup<T> {
202206
let Some(id) = expected_ints.get(&int) else {
203207
return Ok(None);
204208
};
209+
205210
Ok(Some((input, &self.values[*id])))
206211
}
207212

213+
fn try_from_dec_to_str<'a, 'py, I: Input<'py> + ?Sized>(
214+
&self,
215+
py: Python<'py>,
216+
input: &'a I,
217+
expected_strings: &AHashMap<String, usize>,
218+
) -> ValResult<Option<(&'a I, &T)>> {
219+
let Some(py_input) = input.as_python() else {
220+
return Ok(None);
221+
};
222+
223+
if let Ok(false) = py_input.is_instance(get_decimal_type(py)) {
224+
return Ok(None);
225+
}
226+
227+
let Ok(EitherInt::Py(dec_value)) = decimal_as_int(input, py_input) else {
228+
return Ok(None);
229+
};
230+
231+
let Ok(either_int) = dec_value.exact_int() else {
232+
return Ok(None);
233+
};
234+
let int = either_int.into_i64(py)?;
235+
if let Some(id) = expected_strings.get(&int.to_string()) {
236+
return Ok(Some((input, &self.values[*id])));
237+
}
238+
239+
Ok(None)
240+
}
241+
208242
/// Used by int enums
209243
pub fn validate_int<'a, 'py, I: Input<'py> + ?Sized>(
210244
&self,
@@ -308,7 +342,7 @@ impl Validator for LiteralValidator {
308342
input: &(impl Input<'py> + ?Sized),
309343
_state: &mut ValidationState<'_, 'py>,
310344
) -> ValResult<PyObject> {
311-
match self.lookup.validate(py, input)? {
345+
match self.lookup.validate(py, input, _state.strict_or(false))? {
312346
Some((_, v)) => Ok(v.clone()),
313347
None => Err(ValError::new(
314348
ErrorType::LiteralError {

src/validators/union.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ impl TaggedUnionValidator {
464464
input: &(impl Input<'py> + ?Sized),
465465
state: &mut ValidationState<'_, 'py>,
466466
) -> ValResult<PyObject> {
467-
if let Ok(Some((tag, validator))) = self.lookup.validate(py, tag) {
467+
if let Ok(Some((tag, validator))) = self.lookup.validate(py, tag, state.strict_or(false)) {
468468
return match validator.validate(py, input, state) {
469469
Ok(res) => Ok(res),
470470
Err(err) => Err(err.with_outer_location(tag)),

tests/validators/test_enums.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import re
22
import sys
33
from decimal import Decimal
4-
from enum import Enum, IntEnum, IntFlag
4+
from enum import Enum, IntEnum, IntFlag, StrEnum
55

66
import pytest
77

@@ -352,31 +352,100 @@ class ColorEnum(IntEnum):
352352
[-1, 0, 1],
353353
)
354354
def test_enum_int_validation_should_succeed_for_decimal(value: int):
355+
# GIVEN
355356
class MyEnum(Enum):
356357
VALUE = value
357358

359+
class MyIntEnum(IntEnum):
360+
VALUE = value
361+
362+
# WHEN
358363
v = SchemaValidator(
359364
core_schema.with_default_schema(
360365
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
361366
default=MyEnum.VALUE,
362367
)
363368
)
369+
370+
v_int = SchemaValidator(
371+
core_schema.with_default_schema(
372+
schema=core_schema.enum_schema(MyIntEnum, list(MyIntEnum.__members__.values())),
373+
default=MyIntEnum.VALUE,
374+
)
375+
)
376+
377+
# THEN
364378
assert v.validate_python(Decimal(value)) is MyEnum.VALUE
365379
assert v.validate_python(Decimal(float(value))) is MyEnum.VALUE
366380

381+
assert v_int.validate_python(Decimal(value)) is MyIntEnum.VALUE
382+
assert v_int.validate_python(Decimal(float(value))) is MyIntEnum.VALUE
383+
384+
385+
def test_enum_str_validation_should_succeed_for_decimal_with_strict_disabled():
386+
# GIVEN
387+
class MyEnum(StrEnum):
388+
VALUE = '1'
389+
390+
# WHEN
391+
v = SchemaValidator(
392+
core_schema.with_default_schema(
393+
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
394+
default=MyEnum.VALUE,
395+
)
396+
)
397+
398+
# THEN
399+
assert v.validate_python(Decimal(1)) is MyEnum.VALUE
400+
401+
402+
def test_enum_str_validation_should_fail_for_decimal_with_strict_enabled():
403+
# GIVEN
404+
class MyEnum(StrEnum):
405+
VALUE = '1'
406+
407+
# WHEN
408+
v = SchemaValidator(
409+
core_schema.with_default_schema(
410+
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values()), strict=True),
411+
default=MyEnum.VALUE,
412+
)
413+
)
414+
415+
# THEN
416+
with pytest.raises(ValidationError):
417+
v.validate_python(Decimal(1))
418+
367419

368420
def test_enum_int_validation_should_fail_for_incorrect_decimal_value():
421+
# GIVEN
369422
class MyEnum(Enum):
370423
VALUE = 1
371424

425+
class MyStrEnum(StrEnum):
426+
VALUE = '2'
427+
428+
# WHEN
372429
v = SchemaValidator(
373430
core_schema.with_default_schema(
374431
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
375432
default=MyEnum.VALUE,
376433
)
377434
)
435+
436+
v_str = SchemaValidator(
437+
core_schema.with_default_schema(
438+
schema=core_schema.enum_schema(MyStrEnum, list(MyStrEnum.__members__.values())),
439+
default=MyStrEnum.VALUE,
440+
)
441+
)
442+
443+
# THEN
378444
with pytest.raises(ValidationError):
379445
v.validate_python(Decimal(2))
380446

381447
with pytest.raises(ValidationError):
382448
v.validate_python((1, 2))
449+
450+
with pytest.raises(ValidationError):
451+
v_str.validate_python(Decimal(1))

0 commit comments

Comments
 (0)