Skip to content

Support wider variety of enum validation cases #1456

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/validators/enum_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,15 @@ impl<T: EnumValidateValue> Validator for EnumValidator<T> {
},
input,
));
} else if let Some(v) = T::validate_value(py, input, &self.lookup, strict)? {
state.floor_exactness(Exactness::Lax);
}

state.floor_exactness(Exactness::Lax);

if let Some(v) = T::validate_value(py, input, &self.lookup, strict)? {
return Ok(v);
} else if let Ok(res) = class.as_unbound().call1(py, (input.as_python(),)) {
return Ok(res);
} else if let Some(ref missing) = self.missing {
state.floor_exactness(Exactness::Lax);
let enum_value = missing.bind(py).call1((input.to_object(py),)).map_err(|_| {
ValError::new(
ErrorType::Enum {
Expand All @@ -146,6 +150,7 @@ impl<T: EnumValidateValue> Validator for EnumValidator<T> {
return Err(type_error.into());
}
}

Err(ValError::new(
ErrorType::Enum {
expected: self.expected_repr.clone(),
Expand Down
143 changes: 143 additions & 0 deletions tests/validators/test_enums.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
import sys
from decimal import Decimal
from enum import Enum, IntEnum, IntFlag

import pytest
Expand Down Expand Up @@ -344,3 +345,145 @@ class ColorEnum(IntEnum):

assert v.validate_python(ColorEnum.GREEN) is ColorEnum.GREEN
assert v.validate_python(1 << 63) is ColorEnum.GREEN


@pytest.mark.parametrize(
'value',
[-1, 0, 1],
)
def test_enum_int_validation_should_succeed_for_decimal(value: int):
class MyEnum(Enum):
VALUE = value

class MyIntEnum(IntEnum):
VALUE = value

v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

v_int = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyIntEnum, list(MyIntEnum.__members__.values())),
default=MyIntEnum.VALUE,
)
)

assert v.validate_python(Decimal(value)) is MyEnum.VALUE
assert v.validate_python(Decimal(float(value))) is MyEnum.VALUE
assert v_int.validate_python(Decimal(value)) is MyIntEnum.VALUE
assert v_int.validate_python(Decimal(float(value))) is MyIntEnum.VALUE


@pytest.mark.skipif(
sys.version_info >= (3, 13),
reason='Python 3.13+ enum initialization is different, see https://github.com/python/cpython/blob/ec610069637d56101896803a70d418a89afe0b4b/Lib/enum.py#L1159-L1163',
)
def test_enum_int_validation_should_succeed_for_custom_type():
class AnyWrapper:
def __init__(self, value):
self.value = value

def __eq__(self, other: object) -> bool:
return self.value == other

class MyEnum(Enum):
VALUE = 999
SECOND_VALUE = 1000000
THIRD_VALUE = 'Py03'

v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

assert v.validate_python(AnyWrapper(999)) is MyEnum.VALUE
assert v.validate_python(AnyWrapper(1000000)) is MyEnum.SECOND_VALUE
assert v.validate_python(AnyWrapper('Py03')) is MyEnum.THIRD_VALUE


def test_enum_str_validation_should_fail_for_decimal_when_expecting_str_value():
class MyEnum(Enum):
VALUE = '1'

v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

with pytest.raises(ValidationError):
v.validate_python(Decimal(1))


def test_enum_int_validation_should_fail_for_incorrect_decimal_value():
class MyEnum(Enum):
VALUE = 1

v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

with pytest.raises(ValidationError):
v.validate_python(Decimal(2))

with pytest.raises(ValidationError):
v.validate_python((1, 2))

with pytest.raises(ValidationError):
v.validate_python(Decimal(1.1))


def test_enum_int_validation_should_fail_for_plain_type_without_eq_checking():
class MyEnum(Enum):
VALUE = 1

class MyClass:
def __init__(self, value):
self.value = value

v = SchemaValidator(
core_schema.with_default_schema(
schema=core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())),
default=MyEnum.VALUE,
)
)

with pytest.raises(ValidationError):
v.validate_python(MyClass(1))


def support_custom_new_method() -> None:
"""Demonstrates support for custom new methods, as well as conceptually, multi-value enums without dependency on a 3rd party lib for testing."""

class Animal(Enum):
CAT = 'cat', 'meow'
DOG = 'dog', 'woof'

def __new__(cls, species: str, sound: str):
obj = object.__new__(cls)

obj._value_ = species
obj._all_values = (species, sound)

obj.species = species
obj.sound = sound

cls._value2member_map_[sound] = obj

return obj

v = SchemaValidator(core_schema.enum_schema(Animal, list(Animal.__members__.values())))
assert v.validate_python('cat') is Animal.CAT
assert v.validate_python('meow') is Animal.CAT
assert v.validate_python('dog') is Animal.DOG
assert v.validate_python('woof') is Animal.DOG
65 changes: 65 additions & 0 deletions tests/validators/test_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import re
import sys
from copy import deepcopy
from decimal import Decimal
from typing import Any, Callable, Dict, List, Set, Tuple

import pytest
Expand Down Expand Up @@ -1312,3 +1314,66 @@ class OtherModel:
'ctx': {'class_name': 'MyModel'},
}
]


@pytest.mark.skipif(
sys.version_info >= (3, 13),
reason='Python 3.13+ enum initialization is different, see https://github.com/python/cpython/blob/ec610069637d56101896803a70d418a89afe0b4b/Lib/enum.py#L1159-L1163',
)
def test_model_with_enum_int_field_validation_should_succeed_for_any_type_equality_checks():
# GIVEN
from enum import Enum

class EnumClass(Enum):
enum_value = 1
enum_value_2 = 2
enum_value_3 = 3

class IntWrappable:
def __init__(self, value: int):
self.value = value

def __eq__(self, other: object) -> bool:
return self.value == other

class MyModel:
__slots__ = (
'__dict__',
'__pydantic_fields_set__',
'__pydantic_extra__',
'__pydantic_private__',
)
enum_field: EnumClass

# WHEN
v = SchemaValidator(
core_schema.model_schema(
MyModel,
core_schema.model_fields_schema(
{
'enum_field': core_schema.model_field(
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
),
'enum_field_2': core_schema.model_field(
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
),
'enum_field_3': core_schema.model_field(
core_schema.enum_schema(EnumClass, list(EnumClass.__members__.values()))
),
}
),
)
)

# THEN
v.validate_json('{"enum_field": 1, "enum_field_2": 2, "enum_field_3": 3}')
m = v.validate_python(
{
'enum_field': Decimal(1),
'enum_field_2': Decimal(2),
'enum_field_3': IntWrappable(3),
}
)
v.validate_assignment(m, 'enum_field', Decimal(1))
v.validate_assignment(m, 'enum_field_2', Decimal(2))
v.validate_assignment(m, 'enum_field_3', IntWrappable(3))
Loading