Skip to content

Commit 6472887

Browse files
WIP: Simplify shared union serializer logic (#1538)
Co-authored-by: David Hewitt <[email protected]>
1 parent 061711f commit 6472887

File tree

1 file changed

+84
-177
lines changed
  • src/serializers/type_serializers

1 file changed

+84
-177
lines changed

src/serializers/type_serializers/union.rs

Lines changed: 84 additions & 177 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ use std::borrow::Cow;
88
use crate::build_tools::py_schema_err;
99
use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD};
1010
use crate::definitions::DefinitionsBuilder;
11+
use crate::serializers::PydanticSerializationUnexpectedValue;
1112
use crate::tools::{truncate_safe_repr, SchemaDict};
12-
use crate::PydanticSerializationUnexpectedValue;
1313

1414
use super::{
1515
infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerCheck,
@@ -70,22 +70,23 @@ impl UnionSerializer {
7070

7171
impl_py_gc_traverse!(UnionSerializer { choices });
7272

73-
fn to_python(
74-
value: &Bound<'_, PyAny>,
75-
include: Option<&Bound<'_, PyAny>>,
76-
exclude: Option<&Bound<'_, PyAny>>,
73+
fn union_serialize<S>(
74+
// if `union_serialize` returns `Ok(Some(v))`, we picked a union variant to serialize,
75+
// or `Ok(None)` if we couldn't find a suitable variant to serialize.
76+
// Finally, `Err(err)` if we encountered errors while trying to serialize
77+
mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>,
7778
extra: &Extra,
7879
choices: &[CombinedSerializer],
7980
retry_with_lax_check: bool,
80-
) -> PyResult<PyObject> {
81+
) -> PyResult<Option<S>> {
8182
// try the serializers in left to right order with error_on fallback=true
8283
let mut new_extra = extra.clone();
8384
new_extra.check = SerCheck::Strict;
8485
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();
8586

8687
for comb_serializer in choices {
87-
match comb_serializer.to_python(value, include, exclude, &new_extra) {
88-
Ok(v) => return Ok(v),
88+
match selector(comb_serializer, &new_extra) {
89+
Ok(v) => return Ok(Some(v)),
8990
Err(err) => errors.push(err),
9091
}
9192
}
@@ -94,8 +95,8 @@ fn to_python(
9495
if extra.check != SerCheck::Strict && retry_with_lax_check {
9596
new_extra.check = SerCheck::Lax;
9697
for comb_serializer in choices {
97-
if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) {
98-
return Ok(v);
98+
if let Ok(v) = selector(comb_serializer, &new_extra) {
99+
return Ok(Some(v));
99100
}
100101
}
101102
}
@@ -113,94 +114,45 @@ fn to_python(
113114
return Err(PydanticSerializationUnexpectedValue::new_err(Some(message)));
114115
}
115116

116-
infer_to_python(value, include, exclude, extra)
117+
Ok(None)
117118
}
118119

119-
fn json_key<'a>(
120-
key: &'a Bound<'_, PyAny>,
120+
fn tagged_union_serialize<S>(
121+
discriminator_value: Option<Py<PyAny>>,
122+
lookup: &HashMap<String, usize>,
123+
// if this returns `Ok(v)`, we picked a union variant to serialize, where
124+
// `S` is intermediate state which can be passed on to the finalizer
125+
mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>,
121126
extra: &Extra,
122127
choices: &[CombinedSerializer],
123128
retry_with_lax_check: bool,
124-
) -> PyResult<Cow<'a, str>> {
129+
) -> PyResult<Option<S>> {
125130
let mut new_extra = extra.clone();
126131
new_extra.check = SerCheck::Strict;
127-
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();
128-
129-
for comb_serializer in choices {
130-
match comb_serializer.json_key(key, &new_extra) {
131-
Ok(v) => return Ok(v),
132-
Err(err) => errors.push(err),
133-
}
134-
}
135132

136-
// If extra.check is SerCheck::Strict, we're in a nested union
137-
if extra.check != SerCheck::Strict && retry_with_lax_check {
138-
new_extra.check = SerCheck::Lax;
139-
for comb_serializer in choices {
140-
if let Ok(v) = comb_serializer.json_key(key, &new_extra) {
141-
return Ok(v);
142-
}
143-
}
144-
}
145-
146-
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
147-
if extra.check == SerCheck::None {
148-
for err in &errors {
149-
extra.warnings.custom_warning(err.to_string());
150-
}
151-
}
152-
// Otherwise, if we've encountered errors, return them to the parent union, which should take
153-
// care of the formatting for us
154-
else if !errors.is_empty() {
155-
let message = errors.iter().map(ToString::to_string).collect::<Vec<_>>().join("\n");
156-
return Err(PydanticSerializationUnexpectedValue::new_err(Some(message)));
157-
}
158-
infer_json_key(key, extra)
159-
}
160-
161-
#[allow(clippy::too_many_arguments)]
162-
fn serde_serialize<S: serde::ser::Serializer>(
163-
value: &Bound<'_, PyAny>,
164-
serializer: S,
165-
include: Option<&Bound<'_, PyAny>>,
166-
exclude: Option<&Bound<'_, PyAny>>,
167-
extra: &Extra,
168-
choices: &[CombinedSerializer],
169-
retry_with_lax_check: bool,
170-
) -> Result<S::Ok, S::Error> {
171-
let py = value.py();
172-
let mut new_extra = extra.clone();
173-
new_extra.check = SerCheck::Strict;
174-
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();
175-
176-
for comb_serializer in choices {
177-
match comb_serializer.to_python(value, include, exclude, &new_extra) {
178-
Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra),
179-
Err(err) => errors.push(err),
180-
}
181-
}
182-
183-
// If extra.check is SerCheck::Strict, we're in a nested union
184-
if extra.check != SerCheck::Strict && retry_with_lax_check {
185-
new_extra.check = SerCheck::Lax;
186-
for comb_serializer in choices {
187-
if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) {
188-
return infer_serialize(v.bind(py), serializer, None, None, extra);
133+
if let Some(tag) = discriminator_value {
134+
let tag_str = tag.to_string();
135+
if let Some(&serializer_index) = lookup.get(&tag_str) {
136+
let selected_serializer = &choices[serializer_index];
137+
138+
match selector(selected_serializer, &new_extra) {
139+
Ok(v) => return Ok(Some(v)),
140+
Err(_) => {
141+
if retry_with_lax_check {
142+
new_extra.check = SerCheck::Lax;
143+
if let Ok(v) = selector(selected_serializer, &new_extra) {
144+
return Ok(Some(v));
145+
}
146+
}
147+
}
189148
}
190149
}
191150
}
192151

193-
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
194-
if extra.check == SerCheck::None {
195-
for err in &errors {
196-
extra.warnings.custom_warning(err.to_string());
197-
}
198-
} else {
199-
// NOTE: if this function becomes recursive at some point, an `Err(_)` containing the errors
200-
// will have to be returned here
201-
}
202-
203-
infer_serialize(value, serializer, include, exclude, extra)
152+
// if we haven't returned at this point, we should fall back to the union serializer
153+
// which preserves the historical expectation that we do our best with serialization
154+
// even if that means we resort to inference
155+
union_serialize(selector, extra, choices, retry_with_lax_check)
204156
}
205157

206158
impl TypeSerializer for UnionSerializer {
@@ -211,18 +163,23 @@ impl TypeSerializer for UnionSerializer {
211163
exclude: Option<&Bound<'_, PyAny>>,
212164
extra: &Extra,
213165
) -> PyResult<PyObject> {
214-
to_python(
215-
value,
216-
include,
217-
exclude,
166+
union_serialize(
167+
|comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra),
218168
extra,
219169
&self.choices,
220170
self.retry_with_lax_check(),
221-
)
171+
)?
172+
.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok)
222173
}
223174

224175
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
225-
json_key(key, extra, &self.choices, self.retry_with_lax_check())
176+
union_serialize(
177+
|comb_serializer, new_extra| comb_serializer.json_key(key, new_extra),
178+
extra,
179+
&self.choices,
180+
self.retry_with_lax_check(),
181+
)?
182+
.map_or_else(|| infer_json_key(key, extra), Ok)
226183
}
227184

228185
fn serde_serialize<S: serde::ser::Serializer>(
@@ -233,15 +190,16 @@ impl TypeSerializer for UnionSerializer {
233190
exclude: Option<&Bound<'_, PyAny>>,
234191
extra: &Extra,
235192
) -> Result<S::Ok, S::Error> {
236-
serde_serialize(
237-
value,
238-
serializer,
239-
include,
240-
exclude,
193+
match union_serialize(
194+
|comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra),
241195
extra,
242196
&self.choices,
243197
self.retry_with_lax_check(),
244-
)
198+
) {
199+
Ok(Some(v)) => return infer_serialize(v.bind(value.py()), serializer, None, None, extra),
200+
Ok(None) => infer_serialize(value, serializer, include, exclude, extra),
201+
Err(err) => Err(serde::ser::Error::custom(err.to_string())),
202+
}
245203
}
246204

247205
fn get_name(&self) -> &str {
@@ -309,62 +267,29 @@ impl TypeSerializer for TaggedUnionSerializer {
309267
exclude: Option<&Bound<'_, PyAny>>,
310268
extra: &Extra,
311269
) -> PyResult<PyObject> {
312-
let mut new_extra = extra.clone();
313-
new_extra.check = SerCheck::Strict;
314-
315-
if let Some(tag) = self.get_discriminator_value(value, extra) {
316-
let tag_str = tag.to_string();
317-
if let Some(&serializer_index) = self.lookup.get(&tag_str) {
318-
let serializer = &self.choices[serializer_index];
319-
320-
match serializer.to_python(value, include, exclude, &new_extra) {
321-
Ok(v) => return Ok(v),
322-
Err(_) => {
323-
if self.retry_with_lax_check() {
324-
new_extra.check = SerCheck::Lax;
325-
if let Ok(v) = serializer.to_python(value, include, exclude, &new_extra) {
326-
return Ok(v);
327-
}
328-
}
329-
}
330-
}
331-
}
332-
}
333-
334-
to_python(
335-
value,
336-
include,
337-
exclude,
270+
tagged_union_serialize(
271+
self.get_discriminator_value(value, extra),
272+
&self.lookup,
273+
|comb_serializer: &CombinedSerializer, new_extra: &Extra| {
274+
comb_serializer.to_python(value, include, exclude, new_extra)
275+
},
338276
extra,
339277
&self.choices,
340278
self.retry_with_lax_check(),
341-
)
279+
)?
280+
.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok)
342281
}
343282

344283
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
345-
let mut new_extra = extra.clone();
346-
new_extra.check = SerCheck::Strict;
347-
348-
if let Some(tag) = self.get_discriminator_value(key, extra) {
349-
let tag_str = tag.to_string();
350-
if let Some(&serializer_index) = self.lookup.get(&tag_str) {
351-
let serializer = &self.choices[serializer_index];
352-
353-
match serializer.json_key(key, &new_extra) {
354-
Ok(v) => return Ok(v),
355-
Err(_) => {
356-
if self.retry_with_lax_check() {
357-
new_extra.check = SerCheck::Lax;
358-
if let Ok(v) = serializer.json_key(key, &new_extra) {
359-
return Ok(v);
360-
}
361-
}
362-
}
363-
}
364-
}
365-
}
366-
367-
json_key(key, extra, &self.choices, self.retry_with_lax_check())
284+
tagged_union_serialize(
285+
self.get_discriminator_value(key, extra),
286+
&self.lookup,
287+
|comb_serializer: &CombinedSerializer, new_extra: &Extra| comb_serializer.json_key(key, new_extra),
288+
extra,
289+
&self.choices,
290+
self.retry_with_lax_check(),
291+
)?
292+
.map_or_else(|| infer_json_key(key, extra), Ok)
368293
}
369294

370295
fn serde_serialize<S: serde::ser::Serializer>(
@@ -375,38 +300,20 @@ impl TypeSerializer for TaggedUnionSerializer {
375300
exclude: Option<&Bound<'_, PyAny>>,
376301
extra: &Extra,
377302
) -> Result<S::Ok, S::Error> {
378-
let py = value.py();
379-
let mut new_extra = extra.clone();
380-
new_extra.check = SerCheck::Strict;
381-
382-
if let Some(tag) = self.get_discriminator_value(value, extra) {
383-
let tag_str = tag.to_string();
384-
if let Some(&serializer_index) = self.lookup.get(&tag_str) {
385-
let selected_serializer = &self.choices[serializer_index];
386-
387-
match selected_serializer.to_python(value, include, exclude, &new_extra) {
388-
Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra),
389-
Err(_) => {
390-
if self.retry_with_lax_check() {
391-
new_extra.check = SerCheck::Lax;
392-
if let Ok(v) = selected_serializer.to_python(value, include, exclude, &new_extra) {
393-
return infer_serialize(v.bind(py), serializer, None, None, extra);
394-
}
395-
}
396-
}
397-
}
398-
}
399-
}
400-
401-
serde_serialize(
402-
value,
403-
serializer,
404-
include,
405-
exclude,
303+
match tagged_union_serialize(
304+
None,
305+
&self.lookup,
306+
|comb_serializer: &CombinedSerializer, new_extra: &Extra| {
307+
comb_serializer.to_python(value, include, exclude, new_extra)
308+
},
406309
extra,
407310
&self.choices,
408311
self.retry_with_lax_check(),
409-
)
312+
) {
313+
Ok(Some(v)) => return infer_serialize(v.bind(value.py()), serializer, None, None, extra),
314+
Ok(None) => infer_serialize(value, serializer, include, exclude, extra),
315+
Err(err) => Err(serde::ser::Error::custom(err.to_string())),
316+
}
410317
}
411318

412319
fn get_name(&self) -> &str {

0 commit comments

Comments
 (0)