Skip to content

Commit fd81a75

Browse files
Try each option in union serializer before inference (#1398)
1 parent 863640b commit fd81a75

File tree

4 files changed

+57
-13
lines changed

4 files changed

+57
-13
lines changed

src/serializers/type_serializers/union.rs

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
use pyo3::intern;
22
use pyo3::prelude::*;
33
use pyo3::types::{PyDict, PyList, PyTuple};
4+
use smallvec::SmallVec;
45
use std::borrow::Cow;
56

67
use crate::build_tools::py_schema_err;
78
use crate::definitions::DefinitionsBuilder;
8-
use crate::tools::SchemaDict;
9+
use crate::tools::{SchemaDict, UNION_ERR_SMALLVEC_CAPACITY};
910
use crate::PydanticSerializationUnexpectedValue;
1011

1112
use super::{
12-
infer_json_key, infer_serialize, infer_to_python, py_err_se_err, BuildSerializer, CombinedSerializer, Extra,
13-
SerCheck, TypeSerializer,
13+
infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerCheck,
14+
TypeSerializer,
1415
};
1516

1617
#[derive(Debug, Clone)]
@@ -78,13 +79,14 @@ impl TypeSerializer for UnionSerializer {
7879
// try the serializers in left to right order with error_on fallback=true
7980
let mut new_extra = extra.clone();
8081
new_extra.check = SerCheck::Strict;
82+
let mut errors: SmallVec<[PyErr; UNION_ERR_SMALLVEC_CAPACITY]> = SmallVec::new();
8183

8284
for comb_serializer in &self.choices {
8385
match comb_serializer.to_python(value, include, exclude, &new_extra) {
8486
Ok(v) => return Ok(v),
8587
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(value.py()) {
8688
true => (),
87-
false => return Err(err),
89+
false => errors.push(err),
8890
},
8991
}
9092
}
@@ -95,25 +97,31 @@ impl TypeSerializer for UnionSerializer {
9597
Ok(v) => return Ok(v),
9698
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(value.py()) {
9799
true => (),
98-
false => return Err(err),
100+
false => errors.push(err),
99101
},
100102
}
101103
}
102104
}
103105

106+
for err in &errors {
107+
extra.warnings.custom_warning(err.to_string());
108+
}
109+
104110
extra.warnings.on_fallback_py(self.get_name(), value, extra)?;
105111
infer_to_python(value, include, exclude, extra)
106112
}
107113

108114
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
109115
let mut new_extra = extra.clone();
110116
new_extra.check = SerCheck::Strict;
117+
let mut errors: SmallVec<[PyErr; UNION_ERR_SMALLVEC_CAPACITY]> = SmallVec::new();
118+
111119
for comb_serializer in &self.choices {
112120
match comb_serializer.json_key(key, &new_extra) {
113121
Ok(v) => return Ok(v),
114122
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(key.py()) {
115123
true => (),
116-
false => return Err(err),
124+
false => errors.push(err),
117125
},
118126
}
119127
}
@@ -124,12 +132,16 @@ impl TypeSerializer for UnionSerializer {
124132
Ok(v) => return Ok(v),
125133
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(key.py()) {
126134
true => (),
127-
false => return Err(err),
135+
false => errors.push(err),
128136
},
129137
}
130138
}
131139
}
132140

141+
for err in &errors {
142+
extra.warnings.custom_warning(err.to_string());
143+
}
144+
133145
extra.warnings.on_fallback_py(self.get_name(), key, extra)?;
134146
infer_json_key(key, extra)
135147
}
@@ -145,12 +157,14 @@ impl TypeSerializer for UnionSerializer {
145157
let py = value.py();
146158
let mut new_extra = extra.clone();
147159
new_extra.check = SerCheck::Strict;
160+
let mut errors: SmallVec<[PyErr; UNION_ERR_SMALLVEC_CAPACITY]> = SmallVec::new();
161+
148162
for comb_serializer in &self.choices {
149163
match comb_serializer.to_python(value, include, exclude, &new_extra) {
150164
Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra),
151-
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(py) {
165+
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(value.py()) {
152166
true => (),
153-
false => return Err(py_err_se_err(err)),
167+
false => errors.push(err),
154168
},
155169
}
156170
}
@@ -159,14 +173,18 @@ impl TypeSerializer for UnionSerializer {
159173
for comb_serializer in &self.choices {
160174
match comb_serializer.to_python(value, include, exclude, &new_extra) {
161175
Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra),
162-
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(py) {
176+
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(value.py()) {
163177
true => (),
164-
false => return Err(py_err_se_err(err)),
178+
false => errors.push(err),
165179
},
166180
}
167181
}
168182
}
169183

184+
for err in &errors {
185+
extra.warnings.custom_warning(err.to_string());
186+
}
187+
170188
extra.warnings.on_fallback_ser::<S>(self.get_name(), value, extra)?;
171189
infer_serialize(value, serializer, include, exclude, extra)
172190
}

src/tools.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,5 @@ pub(crate) fn new_py_string<'py>(py: Python<'py>, s: &str, cache_str: StringCach
146146
pystring_fast_new(py, s, ascii_only)
147147
}
148148
}
149+
150+
pub(crate) const UNION_ERR_SMALLVEC_CAPACITY: usize = 4;

src/validators/union.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use crate::errors::{ErrorType, ToErrorValue, ValError, ValLineError, ValResult};
1212
use crate::input::{BorrowInput, Input, ValidatedDict};
1313
use crate::lookup_key::LookupKey;
1414
use crate::py_gc::PyGcTraverse;
15-
use crate::tools::SchemaDict;
15+
use crate::tools::{SchemaDict, UNION_ERR_SMALLVEC_CAPACITY};
1616

1717
use super::custom_error::CustomError;
1818
use super::literal::LiteralLookup;
@@ -249,7 +249,7 @@ struct ChoiceLineErrors<'a> {
249249

250250
enum MaybeErrors<'a> {
251251
Custom(&'a CustomError),
252-
Errors(SmallVec<[ChoiceLineErrors<'a>; 4]>),
252+
Errors(SmallVec<[ChoiceLineErrors<'a>; UNION_ERR_SMALLVEC_CAPACITY]>),
253253
}
254254

255255
impl<'a> MaybeErrors<'a> {

tests/serializers/test_union.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,3 +626,27 @@ def test_union_serializer_picks_exact_type_over_subclass_json(
626626
)
627627
assert s.to_python(input_value, mode='json') == expected_value
628628
assert s.to_json(input_value) == json.dumps(expected_value).encode()
629+
630+
631+
def test_custom_serializer() -> None:
632+
s = SchemaSerializer(
633+
core_schema.union_schema(
634+
[
635+
core_schema.dict_schema(
636+
keys_schema=core_schema.any_schema(),
637+
values_schema=core_schema.any_schema(),
638+
serialization=core_schema.plain_serializer_function_ser_schema(lambda x: x['id']),
639+
),
640+
core_schema.list_schema(
641+
items_schema=core_schema.dict_schema(
642+
keys_schema=core_schema.any_schema(),
643+
values_schema=core_schema.any_schema(),
644+
serialization=core_schema.plain_serializer_function_ser_schema(lambda x: x['id']),
645+
)
646+
),
647+
]
648+
)
649+
)
650+
print(s)
651+
assert s.to_python([{'id': 1}, {'id': 2}]) == [1, 2]
652+
assert s.to_python({'id': 1}) == 1

0 commit comments

Comments
 (0)