Skip to content

Commit a5838c7

Browse files
committed
fix(union_serializer): do not raise warnings in nested unions
In case unions of unions are used, this will bubble-up the errors rather than warning immediately. If no solution is found among all serializers by the top-level union, it will warn as before. Signed-off-by: Luka Peschke <[email protected]>
1 parent 9217019 commit a5838c7

File tree

2 files changed

+155
-7
lines changed

2 files changed

+155
-7
lines changed

src/serializers/type_serializers/union.rs

Lines changed: 89 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,61 @@ impl UnionSerializer {
6565
}
6666
}
6767
}
68+
69+
fn _to_python(
70+
&self,
71+
value: &Bound<'_, PyAny>,
72+
include: Option<&Bound<'_, PyAny>>,
73+
exclude: Option<&Bound<'_, PyAny>>,
74+
extra: &Extra,
75+
) -> ToPythonExtractorResult {
76+
to_python_extractor(value, include, exclude, extra, &self.choices)
77+
}
6878
}
6979

7080
impl_py_gc_traverse!(UnionSerializer { choices });
7181

82+
#[derive(Debug)]
83+
enum ToPythonExtractorResult {
84+
Success(PyObject),
85+
Errors(SmallVec<[PyErr; SMALL_UNION_THRESHOLD]>),
86+
}
87+
88+
fn to_python_extractor(
89+
value: &Bound<'_, PyAny>,
90+
include: Option<&Bound<'_, PyAny>>,
91+
exclude: Option<&Bound<'_, PyAny>>,
92+
extra: &Extra,
93+
choices: &[CombinedSerializer],
94+
) -> ToPythonExtractorResult {
95+
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();
96+
97+
for comb_serializer in choices {
98+
match comb_serializer {
99+
CombinedSerializer::Union(union_serializer) => {
100+
match union_serializer._to_python(value, include, exclude, extra) {
101+
ToPythonExtractorResult::Errors(errs) => errors.extend(errs),
102+
ToPythonExtractorResult::Success(success) => return ToPythonExtractorResult::Success(success),
103+
}
104+
}
105+
CombinedSerializer::TaggedUnion(tagged_union_serializer) => {
106+
match tagged_union_serializer._to_python(value, include, exclude, extra) {
107+
ToPythonExtractorResult::Errors(errs) => errors.extend(errs),
108+
ToPythonExtractorResult::Success(success) => return ToPythonExtractorResult::Success(success),
109+
}
110+
}
111+
_ => {
112+
match comb_serializer.to_python(value, include, exclude, extra) {
113+
Ok(v) => return ToPythonExtractorResult::Success(v),
114+
Err(err) => errors.push(err),
115+
};
116+
}
117+
}
118+
}
119+
120+
ToPythonExtractorResult::Errors(errors)
121+
}
122+
72123
fn to_python(
73124
value: &Bound<'_, PyAny>,
74125
include: Option<&Bound<'_, PyAny>>,
@@ -80,14 +131,13 @@ fn to_python(
80131
// try the serializers in left to right order with error_on fallback=true
81132
let mut new_extra = extra.clone();
82133
new_extra.check = SerCheck::Strict;
83-
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();
84134

85-
for comb_serializer in choices {
86-
match comb_serializer.to_python(value, include, exclude, &new_extra) {
87-
Ok(v) => return Ok(v),
88-
Err(err) => errors.push(err),
89-
}
90-
}
135+
let res = to_python_extractor(value, include, exclude, &new_extra, choices);
136+
137+
let errors = match res {
138+
ToPythonExtractorResult::Success(obj) => return Ok(obj),
139+
ToPythonExtractorResult::Errors(errs) => errs,
140+
};
91141

92142
if retry_with_lax_check {
93143
new_extra.check = SerCheck::Lax;
@@ -392,6 +442,38 @@ impl TypeSerializer for TaggedUnionSerializer {
392442
}
393443

394444
impl TaggedUnionSerializer {
445+
fn _to_python(
446+
&self,
447+
value: &Bound<'_, PyAny>,
448+
include: Option<&Bound<'_, PyAny>>,
449+
exclude: Option<&Bound<'_, PyAny>>,
450+
extra: &Extra,
451+
) -> ToPythonExtractorResult {
452+
let mut new_extra = extra.clone();
453+
new_extra.check = SerCheck::Strict;
454+
455+
if let Some(tag) = self.get_discriminator_value(value, extra) {
456+
let tag_str = tag.to_string();
457+
if let Some(&serializer_index) = self.lookup.get(&tag_str) {
458+
let serializer = &self.choices[serializer_index];
459+
460+
match serializer.to_python(value, include, exclude, &new_extra) {
461+
Ok(v) => return ToPythonExtractorResult::Success(v),
462+
Err(_) => {
463+
if self.retry_with_lax_check() {
464+
new_extra.check = SerCheck::Lax;
465+
if let Ok(v) = serializer.to_python(value, include, exclude, &new_extra) {
466+
return ToPythonExtractorResult::Success(v);
467+
}
468+
}
469+
}
470+
}
471+
}
472+
}
473+
474+
to_python_extractor(value, include, exclude, extra, &self.choices)
475+
}
476+
395477
fn get_discriminator_value(&self, value: &Bound<'_, PyAny>, extra: &Extra) -> Option<Py<PyAny>> {
396478
let py = value.py();
397479
let discriminator_value = match &self.discriminator {

tests/serializers/test_union.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,3 +778,69 @@ class ModelB:
778778
model_b = ModelB(field=1)
779779
assert s.to_python(model_a) == {'field': 1, 'TAG': 'a'}
780780
assert s.to_python(model_b) == {'field': 1, 'TAG': 'b'}
781+
782+
783+
class ModelDog:
784+
def __init__(self, type_: Literal['dog']) -> None:
785+
self.type_ = 'dog'
786+
787+
788+
class ModelCat:
789+
def __init__(self, type_: Literal['cat']) -> None:
790+
self.type_ = 'cat'
791+
792+
793+
def test_union_of_unions_of_models() -> None:
794+
s = SchemaSerializer(
795+
core_schema.union_schema(
796+
[
797+
core_schema.union_schema(
798+
[
799+
core_schema.model_schema(
800+
cls=ModelA,
801+
schema=core_schema.model_fields_schema(
802+
fields={
803+
'a': core_schema.model_field(core_schema.str_schema()),
804+
'b': core_schema.model_field(core_schema.str_schema()),
805+
},
806+
),
807+
),
808+
core_schema.model_schema(
809+
cls=ModelB,
810+
schema=core_schema.model_fields_schema(
811+
fields={
812+
'c': core_schema.model_field(core_schema.str_schema()),
813+
'd': core_schema.model_field(core_schema.str_schema()),
814+
},
815+
),
816+
),
817+
]
818+
),
819+
core_schema.union_schema(
820+
[
821+
core_schema.model_schema(
822+
cls=ModelCat,
823+
schema=core_schema.model_fields_schema(
824+
fields={
825+
'type_': core_schema.model_field(core_schema.literal_schema(['cat'])),
826+
},
827+
),
828+
),
829+
core_schema.model_schema(
830+
cls=ModelDog,
831+
schema=core_schema.model_fields_schema(
832+
fields={
833+
'type_': core_schema.model_field(core_schema.literal_schema(['dog'])),
834+
},
835+
),
836+
),
837+
]
838+
),
839+
]
840+
)
841+
)
842+
843+
assert s.to_python(ModelA(a='a', b='b'), warnings='error') == {'a': 'a', 'b': 'b'}
844+
assert s.to_python(ModelB(c='c', d='d'), warnings='error') == {'c': 'c', 'd': 'd'}
845+
assert s.to_python(ModelCat(type_='cat'), warnings='error') == {'type_': 'cat'}
846+
assert s.to_python(ModelDog(type_='dog'), warnings='error') == {'type_': 'dog'}

0 commit comments

Comments
 (0)