Skip to content

Commit 727deee

Browse files
Fix str subclass validation for enums (#1273)
Co-authored-by: David Hewitt <[email protected]>
1 parent b777774 commit 727deee

File tree

2 files changed

+67
-3
lines changed

2 files changed

+67
-3
lines changed

src/validators/enum_.rs

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use std::marker::PhantomData;
44
use pyo3::exceptions::PyTypeError;
55
use pyo3::intern;
66
use pyo3::prelude::*;
7-
use pyo3::types::{PyDict, PyList, PyType};
7+
use pyo3::types::{PyDict, PyFloat, PyInt, PyList, PyString, PyType};
88

99
use crate::build_tools::{is_strict, py_schema_err};
1010
use crate::errors::{ErrorType, ValError, ValResult};
@@ -167,9 +167,27 @@ impl EnumValidateValue for PlainEnumValidator {
167167
py: Python<'py>,
168168
input: &I,
169169
lookup: &LiteralLookup<PyObject>,
170-
_strict: bool,
170+
strict: bool,
171171
) -> ValResult<Option<PyObject>> {
172-
Ok(lookup.validate(py, input)?.map(|(_, v)| v.clone_ref(py)))
172+
match lookup.validate(py, input)? {
173+
Some((_, v)) => Ok(Some(v.clone_ref(py))),
174+
None => {
175+
if !strict {
176+
if let Some(py_input) = input.as_python() {
177+
// necessary for compatibility with 2.6, where str and int subclasses are allowed
178+
if py_input.is_instance_of::<PyString>() {
179+
return Ok(lookup.validate_str(input, false)?.map(|v| v.clone_ref(py)));
180+
} else if py_input.is_instance_of::<PyInt>() {
181+
return Ok(lookup.validate_int(py, input, false)?.map(|v| v.clone_ref(py)));
182+
// necessary for compatibility with 2.6, where float values are allowed for int enums in lax mode
183+
} else if py_input.is_instance_of::<PyFloat>() {
184+
return Ok(lookup.validate_int(py, input, false)?.map(|v| v.clone_ref(py)));
185+
}
186+
}
187+
}
188+
Ok(None)
189+
}
190+
}
173191
}
174192
}
175193

tests/validators/test_enums.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,52 @@ class MyEnum(Enum):
269269
SchemaValidator(core_schema.enum_schema(MyEnum, []))
270270

271271

272+
def test_enum_with_str_subclass() -> None:
273+
class MyEnum(Enum):
274+
a = 'a'
275+
b = 'b'
276+
277+
v = SchemaValidator(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())))
278+
279+
assert v.validate_python(MyEnum.a) is MyEnum.a
280+
assert v.validate_python('a') is MyEnum.a
281+
282+
class MyStr(str):
283+
pass
284+
285+
assert v.validate_python(MyStr('a')) is MyEnum.a
286+
with pytest.raises(ValidationError):
287+
v.validate_python(MyStr('a'), strict=True)
288+
289+
290+
def test_enum_with_int_subclass() -> None:
291+
class MyEnum(Enum):
292+
a = 1
293+
b = 2
294+
295+
v = SchemaValidator(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())))
296+
297+
assert v.validate_python(MyEnum.a) is MyEnum.a
298+
assert v.validate_python(1) is MyEnum.a
299+
300+
class MyInt(int):
301+
pass
302+
303+
assert v.validate_python(MyInt(1)) is MyEnum.a
304+
with pytest.raises(ValidationError):
305+
v.validate_python(MyInt(1), strict=True)
306+
307+
308+
def test_validate_float_for_int_enum() -> None:
309+
class MyEnum(int, Enum):
310+
a = 1
311+
b = 2
312+
313+
v = SchemaValidator(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())))
314+
315+
assert v.validate_python(1.0) is MyEnum.a
316+
317+
272318
def test_missing_error_converted_to_val_error() -> None:
273319
class MyFlags(IntFlag):
274320
OFF = 0

0 commit comments

Comments
 (0)