Skip to content

Commit 5a1385b

Browse files
authored
dataclass serialization speedups (#1162)
1 parent e1cb0eb commit 5a1385b

File tree

5 files changed

+362
-187
lines changed

5 files changed

+362
-187
lines changed

src/serializers/fields.rs

Lines changed: 162 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,15 @@ pub struct GeneralFieldsSerializer {
100100
required_fields: usize,
101101
}
102102

103+
macro_rules! option_length {
104+
($op_has_len:expr) => {
105+
match $op_has_len {
106+
Some(ref has_len) => has_len.len(),
107+
None => 0,
108+
}
109+
};
110+
}
111+
103112
impl GeneralFieldsSerializer {
104113
pub(super) fn new(
105114
fields: AHashMap<String, SerField>,
@@ -136,50 +145,21 @@ impl GeneralFieldsSerializer {
136145
}
137146
}
138147
}
139-
}
140-
141-
macro_rules! option_length {
142-
($op_has_len:expr) => {
143-
match $op_has_len {
144-
Some(ref has_len) => has_len.len(),
145-
None => 0,
146-
}
147-
};
148-
}
149148

150-
impl_py_gc_traverse!(GeneralFieldsSerializer {
151-
fields,
152-
computed_fields
153-
});
154-
155-
impl TypeSerializer for GeneralFieldsSerializer {
156-
fn to_python(
149+
pub fn main_to_python<'py>(
157150
&self,
158-
value: &PyAny,
159-
include: Option<&PyAny>,
160-
exclude: Option<&PyAny>,
161-
extra: &Extra,
162-
) -> PyResult<PyObject> {
163-
let py = value.py();
164-
// If there is already a model registered (from a dataclass, BaseModel)
165-
// then do not touch it
166-
// If there is no model, we (a TypedDict) are the model
167-
let td_extra = Extra {
168-
model: extra.model.map_or_else(|| Some(value), Some),
169-
..*extra
170-
};
171-
let (main_dict, extra_dict) = if let Some(main_extra_dict) = self.extract_dicts(value) {
172-
main_extra_dict
173-
} else {
174-
td_extra.warnings.on_fallback_py(self.get_name(), value, &td_extra)?;
175-
return infer_to_python(value, include, exclude, &td_extra);
176-
};
177-
151+
py: Python<'py>,
152+
main_iter: impl Iterator<Item = PyResult<(&'py PyAny, &'py PyAny)>>,
153+
include: Option<&'py PyAny>,
154+
exclude: Option<&'py PyAny>,
155+
extra: Extra,
156+
) -> PyResult<&'py PyDict> {
178157
let output_dict = PyDict::new(py);
179158
let mut used_req_fields: usize = 0;
180159

181160
// NOTE! we maintain the order of the input dict assuming that's right
182-
for (key, value) in main_dict {
161+
for result in main_iter {
162+
let (key, value) = result?;
183163
let key_str = key_str(key)?;
184164
let op_field = self.fields.get(key_str);
185165
if extra.exclude_none && value.is_none() {
@@ -190,16 +170,16 @@ impl TypeSerializer for GeneralFieldsSerializer {
190170
}
191171
continue;
192172
}
193-
let extra = Extra {
173+
let field_extra = Extra {
194174
field_name: Some(key_str),
195-
..td_extra
175+
..extra
196176
};
197177
if let Some((next_include, next_exclude)) = self.filter.key_filter(key, include, exclude)? {
198178
if let Some(field) = op_field {
199179
if let Some(ref serializer) = field.serializer {
200-
if !exclude_default(value, &extra, serializer)? {
201-
let value = serializer.to_python(value, next_include, next_exclude, &extra)?;
202-
let output_key = field.get_key_py(output_dict.py(), &extra);
180+
if !exclude_default(value, &field_extra, serializer)? {
181+
let value = serializer.to_python(value, next_include, next_exclude, &field_extra)?;
182+
let output_key = field.get_key_py(output_dict.py(), &field_extra);
203183
output_dict.set_item(output_key, value)?;
204184
}
205185
}
@@ -209,23 +189,140 @@ impl TypeSerializer for GeneralFieldsSerializer {
209189
}
210190
} else if self.mode == FieldsMode::TypedDictAllow {
211191
let value = match &self.extra_serializer {
212-
Some(serializer) => serializer.to_python(value, next_include, next_exclude, &extra)?,
213-
None => infer_to_python(value, next_include, next_exclude, &extra)?,
192+
Some(serializer) => serializer.to_python(value, next_include, next_exclude, &field_extra)?,
193+
None => infer_to_python(value, next_include, next_exclude, &field_extra)?,
214194
};
215195
output_dict.set_item(key, value)?;
216-
} else if extra.check == SerCheck::Strict {
196+
} else if field_extra.check == SerCheck::Strict {
217197
return Err(PydanticSerializationUnexpectedValue::new_err(None));
218198
}
219199
}
220200
}
221-
if td_extra.check.enabled()
201+
202+
if extra.check.enabled()
222203
// If any of these are true we can't count fields
223204
&& !(extra.exclude_defaults || extra.exclude_unset || extra.exclude_none)
224205
// Check for missing fields, we can't have extra fields here
225206
&& self.required_fields > used_req_fields
226207
{
227-
return Err(PydanticSerializationUnexpectedValue::new_err(None));
208+
Err(PydanticSerializationUnexpectedValue::new_err(None))
209+
} else {
210+
Ok(output_dict)
211+
}
212+
}
213+
214+
pub fn main_serde_serialize<'py, S: serde::ser::Serializer>(
215+
&self,
216+
main_iter: impl Iterator<Item = PyResult<(&'py PyAny, &'py PyAny)>>,
217+
expected_len: usize,
218+
serializer: S,
219+
include: Option<&'py PyAny>,
220+
exclude: Option<&'py PyAny>,
221+
extra: Extra,
222+
) -> Result<S::SerializeMap, S::Error> {
223+
// NOTE! As above, we maintain the order of the input dict assuming that's right
224+
// we don't bother with `used_fields` here because on unions, `to_python(..., mode='json')` is used
225+
let mut map = serializer.serialize_map(Some(expected_len))?;
226+
227+
for result in main_iter {
228+
let (key, value) = result.map_err(py_err_se_err)?;
229+
if extra.exclude_none && value.is_none() {
230+
continue;
231+
}
232+
let key_str = key_str(key).map_err(py_err_se_err)?;
233+
let field_extra = Extra {
234+
field_name: Some(key_str),
235+
..extra
236+
};
237+
238+
let filter = self.filter.key_filter(key, include, exclude).map_err(py_err_se_err)?;
239+
if let Some((next_include, next_exclude)) = filter {
240+
if let Some(field) = self.fields.get(key_str) {
241+
if let Some(ref serializer) = field.serializer {
242+
if !exclude_default(value, &field_extra, serializer).map_err(py_err_se_err)? {
243+
let s =
244+
PydanticSerializer::new(value, serializer, next_include, next_exclude, &field_extra);
245+
let output_key = field.get_key_json(key_str, &field_extra);
246+
map.serialize_entry(&output_key, &s)?;
247+
}
248+
}
249+
} else if self.mode == FieldsMode::TypedDictAllow {
250+
let output_key = infer_json_key(key, &field_extra).map_err(py_err_se_err)?;
251+
let s = SerializeInfer::new(value, next_include, next_exclude, &field_extra);
252+
map.serialize_entry(&output_key, &s)?;
253+
}
254+
// no error case here since unions (which need the error case) use `to_python(..., mode='json')`
255+
}
228256
}
257+
Ok(map)
258+
}
259+
260+
pub fn add_computed_fields_python(
261+
&self,
262+
model: Option<&PyAny>,
263+
output_dict: &PyDict,
264+
include: Option<&PyAny>,
265+
exclude: Option<&PyAny>,
266+
extra: &Extra,
267+
) -> PyResult<()> {
268+
if let Some(ref computed_fields) = self.computed_fields {
269+
if let Some(model_value) = model {
270+
let cf_extra = Extra { model, ..*extra };
271+
computed_fields.to_python(model_value, output_dict, &self.filter, include, exclude, &cf_extra)?;
272+
}
273+
}
274+
Ok(())
275+
}
276+
277+
pub fn add_computed_fields_json<S: serde::ser::Serializer>(
278+
&self,
279+
model: Option<&PyAny>,
280+
map: &mut S::SerializeMap,
281+
include: Option<&PyAny>,
282+
exclude: Option<&PyAny>,
283+
extra: &Extra,
284+
) -> Result<(), S::Error> {
285+
if let Some(ref computed_fields) = self.computed_fields {
286+
if let Some(model) = model {
287+
computed_fields.serde_serialize::<S>(model, map, &self.filter, include, exclude, extra)?;
288+
}
289+
}
290+
Ok(())
291+
}
292+
293+
pub fn computed_field_count(&self) -> usize {
294+
option_length!(self.computed_fields)
295+
}
296+
}
297+
298+
impl_py_gc_traverse!(GeneralFieldsSerializer {
299+
fields,
300+
computed_fields
301+
});
302+
303+
impl TypeSerializer for GeneralFieldsSerializer {
304+
fn to_python(
305+
&self,
306+
value: &PyAny,
307+
include: Option<&PyAny>,
308+
exclude: Option<&PyAny>,
309+
extra: &Extra,
310+
) -> PyResult<PyObject> {
311+
let py = value.py();
312+
// If there is already a model registered (from a dataclass, BaseModel)
313+
// then do not touch it
314+
// If there is no model, we (a TypedDict) are the model
315+
let model = extra.model.map_or_else(|| Some(value), Some);
316+
let td_extra = Extra { model, ..*extra };
317+
let (main_dict, extra_dict) = if let Some(main_extra_dict) = self.extract_dicts(value) {
318+
main_extra_dict
319+
} else {
320+
td_extra.warnings.on_fallback_py(self.get_name(), value, &td_extra)?;
321+
return infer_to_python(value, include, exclude, &td_extra);
322+
};
323+
324+
let output_dict = self.main_to_python(py, main_dict.iter().map(Ok), include, exclude, td_extra)?;
325+
229326
// this is used to include `__pydantic_extra__` in serialization on models
230327
if let Some(extra_dict) = extra_dict {
231328
for (key, value) in extra_dict {
@@ -241,11 +338,7 @@ impl TypeSerializer for GeneralFieldsSerializer {
241338
}
242339
}
243340
}
244-
if let Some(ref computed_fields) = self.computed_fields {
245-
if let Some(model) = td_extra.model {
246-
computed_fields.to_python(model, output_dict, &self.filter, include, exclude, &td_extra)?;
247-
}
248-
}
341+
self.add_computed_fields_python(model, output_dict, include, exclude, extra)?;
249342
Ok(output_dict.into_py(py))
250343
}
251344

@@ -271,46 +364,23 @@ impl TypeSerializer for GeneralFieldsSerializer {
271364
// If there is already a model registered (from a dataclass, BaseModel)
272365
// then do not touch it
273366
// If there is no model, we (a TypedDict) are the model
274-
let td_extra = Extra {
275-
model: extra.model.map_or_else(|| Some(value), Some),
276-
..*extra
277-
};
367+
let model = extra.model.map_or_else(|| Some(value), Some);
368+
let td_extra = Extra { model, ..*extra };
278369
let expected_len = match self.mode {
279-
FieldsMode::TypedDictAllow => main_dict.len() + option_length!(self.computed_fields),
280-
_ => self.fields.len() + option_length!(extra_dict) + option_length!(self.computed_fields),
370+
FieldsMode::TypedDictAllow => main_dict.len() + self.computed_field_count(),
371+
_ => self.fields.len() + option_length!(extra_dict) + self.computed_field_count(),
281372
};
282373
// NOTE! As above, we maintain the order of the input dict assuming that's right
283374
// we don't bother with `used_fields` here because on unions, `to_python(..., mode='json')` is used
284-
let mut map = serializer.serialize_map(Some(expected_len))?;
285-
286-
for (key, value) in main_dict {
287-
if extra.exclude_none && value.is_none() {
288-
continue;
289-
}
290-
let key_str = key_str(key).map_err(py_err_se_err)?;
291-
let extra = Extra {
292-
field_name: Some(key_str),
293-
..td_extra
294-
};
375+
let mut map = self.main_serde_serialize(
376+
main_dict.iter().map(Ok),
377+
expected_len,
378+
serializer,
379+
include,
380+
exclude,
381+
td_extra,
382+
)?;
295383

296-
let filter = self.filter.key_filter(key, include, exclude).map_err(py_err_se_err)?;
297-
if let Some((next_include, next_exclude)) = filter {
298-
if let Some(field) = self.fields.get(key_str) {
299-
if let Some(ref serializer) = field.serializer {
300-
if !exclude_default(value, &extra, serializer).map_err(py_err_se_err)? {
301-
let s = PydanticSerializer::new(value, serializer, next_include, next_exclude, &extra);
302-
let output_key = field.get_key_json(key_str, &extra);
303-
map.serialize_entry(&output_key, &s)?;
304-
}
305-
}
306-
} else if self.mode == FieldsMode::TypedDictAllow {
307-
let output_key = infer_json_key(key, &extra).map_err(py_err_se_err)?;
308-
let s = SerializeInfer::new(value, next_include, next_exclude, &extra);
309-
map.serialize_entry(&output_key, &s)?;
310-
}
311-
// no error case here since unions (which need the error case) use `to_python(..., mode='json')`
312-
}
313-
}
314384
// this is used to include `__pydantic_extra__` in serialization on models
315385
if let Some(extra_dict) = extra_dict {
316386
for (key, value) in extra_dict {
@@ -319,17 +389,14 @@ impl TypeSerializer for GeneralFieldsSerializer {
319389
}
320390
let filter = self.filter.key_filter(key, include, exclude).map_err(py_err_se_err)?;
321391
if let Some((next_include, next_exclude)) = filter {
322-
let output_key = infer_json_key(key, &td_extra).map_err(py_err_se_err)?;
323-
let s = SerializeInfer::new(value, next_include, next_exclude, &td_extra);
392+
let output_key = infer_json_key(key, extra).map_err(py_err_se_err)?;
393+
let s = SerializeInfer::new(value, next_include, next_exclude, extra);
324394
map.serialize_entry(&output_key, &s)?;
325395
}
326396
}
327397
}
328-
if let Some(ref computed_fields) = self.computed_fields {
329-
if let Some(model) = td_extra.model {
330-
computed_fields.serde_serialize::<S>(model, &mut map, &self.filter, include, exclude, &td_extra)?;
331-
}
332-
}
398+
399+
self.add_computed_fields_json::<S>(model, &mut map, include, exclude, extra)?;
333400
map.end()
334401
}
335402

0 commit comments

Comments
 (0)