Skip to content

Commit 2419981

Browse files
authored
fix(union_serializer): do not raise warnings in nested unions (#1513)
Signed-off-by: Luka Peschke <[email protected]>
1 parent 4cb82bf commit 2419981

File tree

2 files changed

+259
-10
lines changed

2 files changed

+259
-10
lines changed

src/serializers/type_serializers/union.rs

Lines changed: 37 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@ use crate::build_tools::py_schema_err;
99
use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD};
1010
use crate::definitions::DefinitionsBuilder;
1111
use crate::tools::{truncate_safe_repr, SchemaDict};
12+
use crate::PydanticSerializationUnexpectedValue;
1213

1314
use super::{
1415
infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerCheck,
@@ -89,7 +90,8 @@ fn to_python(
8990
}
9091
}
9192

92-
if retry_with_lax_check {
93+
// If extra.check is SerCheck::Strict, we're in a nested union
94+
if extra.check != SerCheck::Strict && retry_with_lax_check {
9395
new_extra.check = SerCheck::Lax;
9496
for comb_serializer in choices {
9597
if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) {
@@ -98,8 +100,17 @@ fn to_python(
98100
}
99101
}
100102

101-
for err in &errors {
102-
extra.warnings.custom_warning(err.to_string());
103+
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
104+
if extra.check == SerCheck::None {
105+
for err in &errors {
106+
extra.warnings.custom_warning(err.to_string());
107+
}
108+
}
109+
// Otherwise, if we've encountered errors, return them to the parent union, which should take
110+
// care of the formatting for us
111+
else if !errors.is_empty() {
112+
let message = errors.iter().map(ToString::to_string).collect::<Vec<_>>().join("\n");
113+
return Err(PydanticSerializationUnexpectedValue::new_err(Some(message)));
103114
}
104115

105116
infer_to_python(value, include, exclude, extra)
@@ -122,7 +133,8 @@ fn json_key<'a>(
122133
}
123134
}
124135

125-
if retry_with_lax_check {
136+
// If extra.check is SerCheck::Strict, we're in a nested union
137+
if extra.check != SerCheck::Strict && retry_with_lax_check {
126138
new_extra.check = SerCheck::Lax;
127139
for comb_serializer in choices {
128140
if let Ok(v) = comb_serializer.json_key(key, &new_extra) {
@@ -131,10 +143,18 @@ fn json_key<'a>(
131143
}
132144
}
133145

134-
for err in &errors {
135-
extra.warnings.custom_warning(err.to_string());
146+
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
147+
if extra.check == SerCheck::None {
148+
for err in &errors {
149+
extra.warnings.custom_warning(err.to_string());
150+
}
151+
}
152+
// Otherwise, if we've encountered errors, return them to the parent union, which should take
153+
// care of the formatting for us
154+
else if !errors.is_empty() {
155+
let message = errors.iter().map(ToString::to_string).collect::<Vec<_>>().join("\n");
156+
return Err(PydanticSerializationUnexpectedValue::new_err(Some(message)));
136157
}
137-
138158
infer_json_key(key, extra)
139159
}
140160

@@ -160,7 +180,8 @@ fn serde_serialize<S: serde::ser::Serializer>(
160180
}
161181
}
162182

163-
if retry_with_lax_check {
183+
// If extra.check is SerCheck::Strict, we're in a nested union
184+
if extra.check != SerCheck::Strict && retry_with_lax_check {
164185
new_extra.check = SerCheck::Lax;
165186
for comb_serializer in choices {
166187
if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) {
@@ -169,8 +190,14 @@ fn serde_serialize<S: serde::ser::Serializer>(
169190
}
170191
}
171192

172-
for err in &errors {
173-
extra.warnings.custom_warning(err.to_string());
193+
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
194+
if extra.check == SerCheck::None {
195+
for err in &errors {
196+
extra.warnings.custom_warning(err.to_string());
197+
}
198+
} else {
199+
// NOTE: if this function becomes recursive at some point, an `Err(_)` containing the errors
200+
// will have to be returned here
174201
}
175202

176203
infer_serialize(value, serializer, include, exclude, extra)

tests/serializers/test_union.py

Lines changed: 222 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import dataclasses
24
import json
35
import uuid
@@ -778,3 +780,223 @@ class ModelB:
778780
model_b = ModelB(field=1)
779781
assert s.to_python(model_a) == {'field': 1, 'TAG': 'a'}
780782
assert s.to_python(model_b) == {'field': 1, 'TAG': 'b'}
783+
784+
785+
class ModelDog:
786+
def __init__(self, type_: Literal['dog']) -> None:
787+
self.type_ = 'dog'
788+
789+
790+
class ModelCat:
791+
def __init__(self, type_: Literal['cat']) -> None:
792+
self.type_ = 'cat'
793+
794+
795+
class ModelAlien:
796+
def __init__(self, type_: Literal['alien']) -> None:
797+
self.type_ = 'alien'
798+
799+
800+
@pytest.fixture
801+
def model_a_b_union_schema() -> core_schema.UnionSchema:
802+
return core_schema.union_schema(
803+
[
804+
core_schema.model_schema(
805+
cls=ModelA,
806+
schema=core_schema.model_fields_schema(
807+
fields={
808+
'a': core_schema.model_field(core_schema.str_schema()),
809+
'b': core_schema.model_field(core_schema.str_schema()),
810+
},
811+
),
812+
),
813+
core_schema.model_schema(
814+
cls=ModelB,
815+
schema=core_schema.model_fields_schema(
816+
fields={
817+
'c': core_schema.model_field(core_schema.str_schema()),
818+
'd': core_schema.model_field(core_schema.str_schema()),
819+
},
820+
),
821+
),
822+
]
823+
)
824+
825+
826+
@pytest.fixture
827+
def union_of_unions_schema(model_a_b_union_schema: core_schema.UnionSchema) -> core_schema.UnionSchema:
828+
return core_schema.union_schema(
829+
[
830+
model_a_b_union_schema,
831+
core_schema.union_schema(
832+
[
833+
core_schema.model_schema(
834+
cls=ModelCat,
835+
schema=core_schema.model_fields_schema(
836+
fields={
837+
'type_': core_schema.model_field(core_schema.literal_schema(['cat'])),
838+
},
839+
),
840+
),
841+
core_schema.model_schema(
842+
cls=ModelDog,
843+
schema=core_schema.model_fields_schema(
844+
fields={
845+
'type_': core_schema.model_field(core_schema.literal_schema(['dog'])),
846+
},
847+
),
848+
),
849+
]
850+
),
851+
]
852+
)
853+
854+
855+
@pytest.mark.parametrize(
856+
'input,expected',
857+
[
858+
(ModelA(a='a', b='b'), {'a': 'a', 'b': 'b'}),
859+
(ModelB(c='c', d='d'), {'c': 'c', 'd': 'd'}),
860+
(ModelCat(type_='cat'), {'type_': 'cat'}),
861+
(ModelDog(type_='dog'), {'type_': 'dog'}),
862+
],
863+
)
864+
def test_union_of_unions_of_models(union_of_unions_schema: core_schema.UnionSchema, input: Any, expected: Any) -> None:
865+
s = SchemaSerializer(union_of_unions_schema)
866+
assert s.to_python(input, warnings='error') == expected
867+
868+
869+
def test_union_of_unions_of_models_invalid_variant(union_of_unions_schema: core_schema.UnionSchema) -> None:
870+
s = SchemaSerializer(union_of_unions_schema)
871+
# All warnings should be available
872+
messages = [
873+
'Expected `ModelA` but got `ModelAlien`',
874+
'Expected `ModelB` but got `ModelAlien`',
875+
'Expected `ModelCat` but got `ModelAlien`',
876+
'Expected `ModelDog` but got `ModelAlien`',
877+
]
878+
879+
with warnings.catch_warnings(record=True) as w:
880+
warnings.simplefilter('always')
881+
s.to_python(ModelAlien(type_='alien'))
882+
for m in messages:
883+
assert m in str(w[0].message)
884+
885+
886+
@pytest.fixture
887+
def tagged_union_of_unions_schema(model_a_b_union_schema: core_schema.UnionSchema) -> core_schema.UnionSchema:
888+
return core_schema.union_schema(
889+
[
890+
model_a_b_union_schema,
891+
core_schema.tagged_union_schema(
892+
discriminator='type_',
893+
choices={
894+
'cat': core_schema.model_schema(
895+
cls=ModelCat,
896+
schema=core_schema.model_fields_schema(
897+
fields={
898+
'type_': core_schema.model_field(core_schema.literal_schema(['cat'])),
899+
},
900+
),
901+
),
902+
'dog': core_schema.model_schema(
903+
cls=ModelDog,
904+
schema=core_schema.model_fields_schema(
905+
fields={
906+
'type_': core_schema.model_field(core_schema.literal_schema(['dog'])),
907+
},
908+
),
909+
),
910+
},
911+
),
912+
]
913+
)
914+
915+
916+
@pytest.mark.parametrize(
917+
'input,expected',
918+
[
919+
(ModelA(a='a', b='b'), {'a': 'a', 'b': 'b'}),
920+
(ModelB(c='c', d='d'), {'c': 'c', 'd': 'd'}),
921+
(ModelCat(type_='cat'), {'type_': 'cat'}),
922+
(ModelDog(type_='dog'), {'type_': 'dog'}),
923+
],
924+
)
925+
def test_union_of_unions_of_models_with_tagged_union(
926+
tagged_union_of_unions_schema: core_schema.UnionSchema, input: Any, expected: Any
927+
) -> None:
928+
s = SchemaSerializer(tagged_union_of_unions_schema)
929+
assert s.to_python(input, warnings='error') == expected
930+
931+
932+
def test_union_of_unions_of_models_with_tagged_union_invalid_variant(
933+
tagged_union_of_unions_schema: core_schema.UnionSchema,
934+
) -> None:
935+
s = SchemaSerializer(tagged_union_of_unions_schema)
936+
# All warnings should be available
937+
messages = [
938+
'Expected `ModelA` but got `ModelAlien`',
939+
'Expected `ModelB` but got `ModelAlien`',
940+
'Expected `ModelCat` but got `ModelAlien`',
941+
'Expected `ModelDog` but got `ModelAlien`',
942+
]
943+
944+
with warnings.catch_warnings(record=True) as w:
945+
warnings.simplefilter('always')
946+
s.to_python(ModelAlien(type_='alien'))
947+
for m in messages:
948+
assert m in str(w[0].message)
949+
950+
951+
@pytest.mark.parametrize(
952+
'input,expected',
953+
[
954+
({True: '1'}, b'{"true":"1"}'),
955+
({1: '1'}, b'{"1":"1"}'),
956+
({2.3: '1'}, b'{"2.3":"1"}'),
957+
({'a': 'b'}, b'{"a":"b"}'),
958+
],
959+
)
960+
def test_union_of_unions_of_models_with_tagged_union_json_key_serialization(
961+
input: dict[bool | int | float | str, str], expected: bytes
962+
) -> None:
963+
s = SchemaSerializer(
964+
core_schema.dict_schema(
965+
keys_schema=core_schema.union_schema(
966+
[
967+
core_schema.union_schema([core_schema.bool_schema(), core_schema.int_schema()]),
968+
core_schema.union_schema([core_schema.float_schema(), core_schema.str_schema()]),
969+
]
970+
),
971+
values_schema=core_schema.str_schema(),
972+
)
973+
)
974+
975+
assert s.to_json(input, warnings='error') == expected
976+
977+
978+
@pytest.mark.parametrize(
979+
'input,expected',
980+
[
981+
({'key': True}, b'{"key":true}'),
982+
({'key': 1}, b'{"key":1}'),
983+
({'key': 2.3}, b'{"key":2.3}'),
984+
({'key': 'a'}, b'{"key":"a"}'),
985+
],
986+
)
987+
def test_union_of_unions_of_models_with_tagged_union_json_serialization(
988+
input: dict[str, bool | int | float | str], expected: bytes
989+
) -> None:
990+
s = SchemaSerializer(
991+
core_schema.dict_schema(
992+
keys_schema=core_schema.str_schema(),
993+
values_schema=core_schema.union_schema(
994+
[
995+
core_schema.union_schema([core_schema.bool_schema(), core_schema.int_schema()]),
996+
core_schema.union_schema([core_schema.float_schema(), core_schema.str_schema()]),
997+
]
998+
),
999+
)
1000+
)
1001+
1002+
assert s.to_json(input, warnings='error') == expected

0 commit comments

Comments
 (0)