Skip to content

Commit a3f13c7

Browse files
authored
fix wrap serializer breaking union serialization in presence of extra fields (#1530)
1 parent cd0346d commit a3f13c7

File tree

4 files changed

+124
-57
lines changed

4 files changed

+124
-57
lines changed

src/serializers/extra.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,10 @@ impl<'a> Extra<'a> {
198198
pub fn serialize_infer<'py>(&'py self, value: &'py Bound<'py, PyAny>) -> super::infer::SerializeInfer<'py> {
199199
super::infer::SerializeInfer::new(value, None, None, self)
200200
}
201+
202+
pub(crate) fn model_type_name(&self) -> Option<Bound<'a, PyString>> {
203+
self.model.and_then(|model| model.get_type().name().ok())
204+
}
201205
}
202206

203207
#[derive(Clone, Copy, PartialEq, Eq)]

src/serializers/fields.rs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,15 @@ impl GeneralFieldsSerializer {
200200
};
201201
output_dict.set_item(key, value)?;
202202
} else if field_extra.check == SerCheck::Strict {
203-
return Err(PydanticSerializationUnexpectedValue::new_err(None));
203+
let type_name = field_extra.model_type_name();
204+
return Err(PydanticSerializationUnexpectedValue::new_err(Some(format!(
205+
"Unexpected field `{key}`{for_type_name}",
206+
for_type_name = if let Some(type_name) = type_name {
207+
format!(" for type `{type_name}`")
208+
} else {
209+
String::new()
210+
},
211+
))));
204212
}
205213
}
206214
}
@@ -212,22 +220,15 @@ impl GeneralFieldsSerializer {
212220
&& self.required_fields > used_req_fields
213221
{
214222
let required_fields = self.required_fields;
215-
let type_name = match extra.model {
216-
Some(model) => model
217-
.get_type()
218-
.qualname()
219-
.ok()
220-
.unwrap_or_else(|| PyString::new_bound(py, "<unknown python object>"))
221-
.to_string(),
222-
None => "<unknown python object>".to_string(),
223-
};
223+
let type_name = extra.model_type_name();
224224
let field_value = match extra.model {
225225
Some(model) => truncate_safe_repr(model, Some(100)),
226226
None => "<unknown python object>".to_string(),
227227
};
228228

229229
Err(PydanticSerializationUnexpectedValue::new_err(Some(format!(
230-
"Expected {required_fields} fields but got {used_req_fields} for type `{type_name}` with value `{field_value}` - serialized value may not be as expected."
230+
"Expected {required_fields} fields but got {used_req_fields}{for_type_name} with value `{field_value}` - serialized value may not be as expected.",
231+
for_type_name = if let Some(type_name) = type_name { format!(" for type `{type_name}`") } else { String::new() },
231232
))))
232233
} else {
233234
Ok(output_dict)

src/serializers/type_serializers/function.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,13 @@ impl FunctionPlainSerializer {
179179
.expect("fallback_serializer unexpectedly none")
180180
.as_ref()
181181
}
182+
183+
fn retry_with_lax_check(&self) -> bool {
184+
self.fallback_serializer
185+
.as_ref()
186+
.map_or(false, |f| f.retry_with_lax_check())
187+
|| self.return_serializer.retry_with_lax_check()
188+
}
182189
}
183190

184191
fn on_error(py: Python, err: PyErr, function_name: &str, extra: &Extra) -> PyResult<()> {
@@ -271,6 +278,10 @@ macro_rules! function_type_serializer {
271278
fn get_name(&self) -> &str {
272279
&self.name
273280
}
281+
282+
fn retry_with_lax_check(&self) -> bool {
283+
self.retry_with_lax_check()
284+
}
274285
}
275286
};
276287
}
@@ -409,6 +420,10 @@ impl FunctionWrapSerializer {
409420
fn get_fallback_serializer(&self) -> &CombinedSerializer {
410421
self.serializer.as_ref()
411422
}
423+
424+
fn retry_with_lax_check(&self) -> bool {
425+
self.serializer.retry_with_lax_check() || self.return_serializer.retry_with_lax_check()
426+
}
412427
}
413428

414429
impl_py_gc_traverse!(FunctionWrapSerializer {

tests/serializers/test_union.py

Lines changed: 93 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -62,53 +62,36 @@ def __init__(self, c, d):
6262
@pytest.fixture(scope='module')
6363
def model_serializer() -> SchemaSerializer:
6464
return SchemaSerializer(
65-
{
66-
'type': 'union',
67-
'choices': [
68-
{
69-
'type': 'model',
70-
'cls': ModelA,
71-
'schema': {
72-
'type': 'model-fields',
73-
'fields': {
74-
'a': {'type': 'model-field', 'schema': {'type': 'bytes'}},
75-
'b': {
76-
'type': 'model-field',
77-
'schema': {
78-
'type': 'float',
79-
'serialization': {
80-
'type': 'format',
81-
'formatting_string': '0.1f',
82-
'when_used': 'unless-none',
83-
},
84-
},
85-
},
86-
},
87-
},
88-
},
89-
{
90-
'type': 'model',
91-
'cls': ModelB,
92-
'schema': {
93-
'type': 'model-fields',
94-
'fields': {
95-
'c': {'type': 'model-field', 'schema': {'type': 'bytes'}},
96-
'd': {
97-
'type': 'model-field',
98-
'schema': {
99-
'type': 'float',
100-
'serialization': {
101-
'type': 'format',
102-
'formatting_string': '0.2f',
103-
'when_used': 'unless-none',
104-
},
105-
},
106-
},
107-
},
108-
},
109-
},
65+
core_schema.union_schema(
66+
[
67+
core_schema.model_schema(
68+
ModelA,
69+
core_schema.model_fields_schema(
70+
{
71+
'a': core_schema.model_field(core_schema.bytes_schema()),
72+
'b': core_schema.model_field(
73+
core_schema.float_schema(
74+
serialization=core_schema.format_ser_schema('0.1f', when_used='unless-none')
75+
)
76+
),
77+
}
78+
),
79+
),
80+
core_schema.model_schema(
81+
ModelB,
82+
core_schema.model_fields_schema(
83+
{
84+
'c': core_schema.model_field(core_schema.bytes_schema()),
85+
'd': core_schema.model_field(
86+
core_schema.float_schema(
87+
serialization=core_schema.format_ser_schema('0.2f', when_used='unless-none')
88+
)
89+
),
90+
}
91+
),
92+
),
11093
],
111-
}
94+
)
11295
)
11396

11497

@@ -782,6 +765,70 @@ class ModelB:
782765
assert s.to_python(model_b) == {'field': 1, 'TAG': 'b'}
783766

784767

768+
def test_union_model_wrap_serializer():
769+
def wrap_serializer(value, handler):
770+
return handler(value)
771+
772+
class Data:
773+
pass
774+
775+
class ModelA:
776+
a: Data
777+
778+
class ModelB:
779+
a: Data
780+
781+
model_serializer = SchemaSerializer(
782+
core_schema.union_schema(
783+
[
784+
core_schema.model_schema(
785+
ModelA,
786+
core_schema.model_fields_schema(
787+
{
788+
'a': core_schema.model_field(
789+
core_schema.model_schema(
790+
Data,
791+
core_schema.model_fields_schema({}),
792+
)
793+
),
794+
},
795+
),
796+
serialization=core_schema.wrap_serializer_function_ser_schema(wrap_serializer),
797+
),
798+
core_schema.model_schema(
799+
ModelB,
800+
core_schema.model_fields_schema(
801+
{
802+
'a': core_schema.model_field(
803+
core_schema.model_schema(
804+
Data,
805+
core_schema.model_fields_schema({}),
806+
)
807+
),
808+
},
809+
),
810+
serialization=core_schema.wrap_serializer_function_ser_schema(wrap_serializer),
811+
),
812+
],
813+
)
814+
)
815+
816+
input_value = ModelA()
817+
input_value.a = Data()
818+
819+
assert model_serializer.to_python(input_value) == {'a': {}}
820+
assert model_serializer.to_python(input_value, mode='json') == {'a': {}}
821+
assert model_serializer.to_json(input_value) == b'{"a":{}}'
822+
823+
# add some additional attribute, should be ignored & not break serialization
824+
825+
input_value.a._a = 'foo'
826+
827+
assert model_serializer.to_python(input_value) == {'a': {}}
828+
assert model_serializer.to_python(input_value, mode='json') == {'a': {}}
829+
assert model_serializer.to_json(input_value) == b'{"a":{}}'
830+
831+
785832
class ModelDog:
786833
def __init__(self, type_: Literal['dog']) -> None:
787834
self.type_ = 'dog'

0 commit comments

Comments
 (0)