Skip to content

Commit 48de7c6

Browse files
committed
validate values against their json form
1 parent 27700b3 commit 48de7c6

File tree

3 files changed

+53
-18
lines changed

3 files changed

+53
-18
lines changed

src/serializers/config.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use crate::tools::SchemaDict;
1414

1515
use super::errors::py_err_se_err;
1616

17-
#[derive(Debug, Clone)]
17+
#[derive(Debug, Clone, Default)]
1818
#[allow(clippy::struct_field_names)]
1919
pub(crate) struct SerializationConfig {
2020
pub timedelta_mode: TimedeltaMode,
@@ -57,6 +57,15 @@ macro_rules! serialization_mode {
5757
$($variant,)*
5858
}
5959

60+
impl std::fmt::Display for $name {
61+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
62+
match self {
63+
$(Self::$variant => write!(f, $value),)*
64+
}
65+
66+
}
67+
}
68+
6069
impl FromStr for $name {
6170
type Err = PyErr;
6271

src/serializers/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use crate::definitions::{Definitions, DefinitionsBuilder};
99
use crate::py_gc::PyGcTraverse;
1010

1111
pub(crate) use config::BytesMode;
12-
use config::SerializationConfig;
12+
pub(crate) use config::SerializationConfig;
1313
pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue};
1414
use extra::{CollectWarnings, SerRecursionState, WarningsMode};
1515
pub(crate) use extra::{DuckTypingSerMode, Extra, SerMode, SerializationState};

src/validators/enum_.rs

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,17 @@ use std::marker::PhantomData;
44
use pyo3::exceptions::PyTypeError;
55
use pyo3::intern;
66
use pyo3::prelude::*;
7-
use pyo3::types::{PyDict, PyFloat, PyInt, PyList, PyString, PyTuple, PyType};
7+
use pyo3::types::{PyDict, PyFloat, PyInt, PyList, PyString, PyType};
88

99
use crate::build_tools::{is_strict, py_schema_err};
1010
use crate::errors::{ErrorType, ValError, ValResult};
1111
use crate::input::Input;
12+
use crate::serializers::{to_jsonable_python, SerializationConfig};
1213
use crate::tools::{safe_repr, SchemaDict};
1314

1415
use super::is_instance::class_repr;
1516
use super::literal::{expected_repr_name, LiteralLookup};
17+
use super::InputType;
1618
use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Exactness, ValidationState, Validator};
1719

1820
#[derive(Debug, Clone)]
@@ -33,36 +35,55 @@ impl BuildValidator for BuildEnumValidator {
3335

3436
let py = schema.py();
3537
let value_str = intern!(py, "value");
36-
let mut expected: Vec<(Bound<'_, PyAny>, PyObject)> = members
38+
let expected_py: Vec<(Bound<'_, PyAny>, PyObject)> = members
3739
.iter()
3840
.map(|v| Ok((v.getattr(value_str)?, v.into())))
3941
.collect::<PyResult<_>>()?;
42+
let ser_config = SerializationConfig::from_config(config).unwrap_or_default();
43+
let expected_json: Vec<(Bound<'_, PyAny>, PyObject)> = members
44+
.iter()
45+
.map(|v| {
46+
Ok((
47+
to_jsonable_python(
48+
py,
49+
&v.getattr(value_str)?,
50+
None,
51+
None,
52+
false,
53+
false,
54+
false,
55+
&ser_config.timedelta_mode.to_string(),
56+
&ser_config.bytes_mode.to_string(),
57+
&ser_config.inf_nan_mode.to_string(),
58+
false,
59+
None,
60+
true,
61+
None,
62+
)?
63+
.into_bound(py),
64+
v.into(),
65+
))
66+
})
67+
.collect::<PyResult<_>>()?;
4068

41-
let repr_args: Vec<String> = expected
69+
let repr_args: Vec<String> = expected_py
4270
.iter()
4371
.map(|(k, _)| k.repr()?.extract())
4472
.collect::<PyResult<_>>()?;
4573

46-
let mut addition = vec![];
47-
for (k, v) in &expected {
48-
if let Ok(ss) = k.downcast::<PyTuple>() {
49-
let list = ss.to_list();
50-
addition.push((list.into_any(), v.clone()));
51-
}
52-
}
53-
expected.append(&mut addition);
54-
5574
let class: Bound<PyType> = schema.get_as_req(intern!(py, "cls"))?;
5675
let class_repr = class_repr(schema, &class)?;
5776

58-
let lookup = LiteralLookup::new(py, expected.into_iter())?;
77+
let py_lookup = LiteralLookup::new(py, expected_py.into_iter())?;
78+
let json_lookup = LiteralLookup::new(py, expected_json.into_iter())?;
5979

6080
macro_rules! build {
6181
($vv:ty, $name_prefix:literal) => {
6282
EnumValidator {
6383
phantom: PhantomData::<$vv>,
6484
class: class.clone().into(),
65-
lookup,
85+
py_lookup,
86+
json_lookup,
6687
missing: schema.get_as(intern!(py, "missing"))?,
6788
expected_repr: expected_repr_name(repr_args, "").0,
6889
strict: is_strict(schema, config)?,
@@ -96,7 +117,8 @@ pub trait EnumValidateValue: std::fmt::Debug + Clone + Send + Sync {
96117
pub struct EnumValidator<T: EnumValidateValue> {
97118
phantom: PhantomData<T>,
98119
class: Py<PyType>,
99-
lookup: LiteralLookup<PyObject>,
120+
py_lookup: LiteralLookup<PyObject>,
121+
json_lookup: LiteralLookup<PyObject>,
100122
missing: Option<PyObject>,
101123
expected_repr: String,
102124
strict: bool,
@@ -129,7 +151,11 @@ impl<T: EnumValidateValue> Validator for EnumValidator<T> {
129151

130152
state.floor_exactness(Exactness::Lax);
131153

132-
if let Some(v) = T::validate_value(py, input, &self.lookup, strict)? {
154+
let lookup = match state.extra().input_type {
155+
InputType::Json => &self.json_lookup,
156+
_ => &self.py_lookup,
157+
};
158+
if let Some(v) = T::validate_value(py, input, lookup, strict)? {
133159
return Ok(v);
134160
} else if let Ok(res) = class.as_unbound().call1(py, (input.as_python(),)) {
135161
return Ok(res);

0 commit comments

Comments
 (0)