Skip to content

Commit 5506e1f

Browse files
committed
feat: apply same logic to serde_serialize and add non-regression test
Signed-off-by: Luka Peschke <[email protected]>
1 parent 5c38506 commit 5506e1f

File tree

2 files changed

+38
-4
lines changed

2 files changed

+38
-4
lines changed

src/serializers/type_serializers/union.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,8 @@ fn serde_serialize<S: serde::ser::Serializer>(
180180
}
181181
}
182182

183-
if retry_with_lax_check {
183+
// If extra.check is SerCheck::Strict, we're in a nested union
184+
if extra.check != SerCheck::Strict && retry_with_lax_check {
184185
new_extra.check = SerCheck::Lax;
185186
for comb_serializer in choices {
186187
if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) {
@@ -189,8 +190,14 @@ fn serde_serialize<S: serde::ser::Serializer>(
189190
}
190191
}
191192

192-
for err in &errors {
193-
extra.warnings.custom_warning(err.to_string());
193+
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
194+
if extra.check == SerCheck::None {
195+
for err in &errors {
196+
extra.warnings.custom_warning(err.to_string());
197+
}
198+
} else {
199+
// NOTE: if this function becomes recursive at some point, an `Err(_)` containing the errors
200+
// will have to be returned here
194201
}
195202

196203
infer_serialize(value, serializer, include, exclude, extra)

tests/serializers/test_union.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -958,7 +958,7 @@ def test_union_of_unions_of_models_with_tagged_union_invalid_variant(
958958
],
959959
)
960960
def test_union_of_unions_of_models_with_tagged_union_json_key_serialization(
961-
input: bool | int | float | str, expected: bytes
961+
input: dict[bool | int | float | str, str], expected: bytes
962962
) -> None:
963963
s = SchemaSerializer(
964964
core_schema.dict_schema(
@@ -973,3 +973,30 @@ def test_union_of_unions_of_models_with_tagged_union_json_key_serialization(
973973
)
974974

975975
assert s.to_json(input, warnings='error') == expected
976+
977+
978+
@pytest.mark.parametrize(
979+
'input,expected',
980+
[
981+
({'key': True}, b'{"key":true}'),
982+
({'key': 1}, b'{"key":1}'),
983+
({'key': 2.3}, b'{"key":2.3}'),
984+
({'key': 'a'}, b'{"key":"a"}'),
985+
],
986+
)
987+
def test_union_of_unions_of_models_with_tagged_union_json_serialization(
988+
input: dict[str, bool | int | float | str], expected: bytes
989+
) -> None:
990+
s = SchemaSerializer(
991+
core_schema.dict_schema(
992+
keys_schema=core_schema.str_schema(),
993+
values_schema=core_schema.union_schema(
994+
[
995+
core_schema.union_schema([core_schema.bool_schema(), core_schema.int_schema()]),
996+
core_schema.union_schema([core_schema.float_schema(), core_schema.str_schema()]),
997+
]
998+
),
999+
)
1000+
)
1001+
1002+
assert s.to_json(input, warnings='error') == expected

0 commit comments

Comments
 (0)