Skip to content

Commit d3f97e7

Browse files
authored
simplify serialization filtering (#578)
1 parent 9ba8acd commit d3f97e7

File tree

5 files changed

+99
-125
lines changed

5 files changed

+99
-125
lines changed

src/serializers/filter.rs

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use std::hash::Hash;
44
use pyo3::exceptions::PyTypeError;
55
use pyo3::intern;
66
use pyo3::prelude::*;
7-
use pyo3::types::{PyBool, PyDict, PySet, PyString};
7+
use pyo3::types::{PyBool, PyDict, PySet};
88

99
use crate::build_tools::SchemaDict;
1010

@@ -63,19 +63,6 @@ impl SchemaFilter<isize> {
6363
Ok(Self { include, exclude })
6464
}
6565

66-
pub fn from_vec_hash(py: Python, exclude: Vec<Py<PyString>>) -> PyResult<Self> {
67-
let exclude = if exclude.is_empty() {
68-
None
69-
} else {
70-
let mut set: AHashSet<isize> = AHashSet::with_capacity(exclude.len());
71-
for item in exclude {
72-
set.insert(item.as_ref(py).hash()?);
73-
}
74-
Some(set)
75-
};
76-
Ok(Self { include: None, exclude })
77-
}
78-
7966
fn build_set_hashes(v: Option<&PyAny>) -> PyResult<Option<AHashSet<isize>>> {
8067
match v {
8168
Some(value) => {

src/serializers/type_serializers/dataclass.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ use crate::build_tools::{py_error_type, SchemaDict};
88
use crate::definitions::DefinitionsBuilder;
99

1010
use super::model::ModelSerializer;
11-
use super::typed_dict::{TypedDictField, TypedDictSerializer};
12-
use super::{BuildSerializer, CombinedSerializer, ComputedFields, SchemaFilter};
11+
use super::typed_dict::{FieldSerializer, TypedDictSerializer};
12+
use super::{BuildSerializer, CombinedSerializer, ComputedFields};
1313

1414
pub struct DataclassArgsBuilder;
1515

@@ -24,8 +24,7 @@ impl BuildSerializer for DataclassArgsBuilder {
2424
let py = schema.py();
2525

2626
let fields_list: &PyList = schema.get_as_req(intern!(py, "fields"))?;
27-
let mut fields: AHashMap<String, TypedDictField> = AHashMap::with_capacity(fields_list.len());
28-
let mut exclude: Vec<Py<PyString>> = Vec::with_capacity(fields_list.len());
27+
let mut fields: AHashMap<String, FieldSerializer> = AHashMap::with_capacity(fields_list.len());
2928

3029
for (index, item) in fields_list.iter().enumerate() {
3130
let field_info: &PyDict = item.downcast()?;
@@ -34,21 +33,20 @@ impl BuildSerializer for DataclassArgsBuilder {
3433
let key_py: Py<PyString> = PyString::intern(py, &name).into_py(py);
3534

3635
if field_info.get_as(intern!(py, "serialization_exclude"))? == Some(true) {
37-
exclude.push(key_py.clone_ref(py));
36+
fields.insert(name, FieldSerializer::new(py, key_py, None, None, true));
3837
} else {
3938
let schema = field_info.get_as_req(intern!(py, "schema"))?;
4039
let serializer = CombinedSerializer::build(schema, config, definitions)
4140
.map_err(|e| py_error_type!("Field `{}`:\n {}", index, e))?;
4241

4342
let alias = field_info.get_as(intern!(py, "serialization_alias"))?;
44-
fields.insert(name, TypedDictField::new(py, key_py, alias, serializer, true));
43+
fields.insert(name, FieldSerializer::new(py, key_py, alias, Some(serializer), true));
4544
}
4645
}
4746

48-
let filter = SchemaFilter::from_vec_hash(py, exclude)?;
4947
let computed_fields = ComputedFields::new(schema)?;
5048

51-
Ok(TypedDictSerializer::new(fields, false, filter, computed_fields).into())
49+
Ok(TypedDictSerializer::new(fields, false, computed_fields).into())
5250
}
5351
}
5452

src/serializers/type_serializers/model.rs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,9 @@ use crate::build_tools::{py_error_type, ExtraBehavior, SchemaDict};
1010
use crate::definitions::DefinitionsBuilder;
1111
use crate::serializers::computed_fields::ComputedFields;
1212
use crate::serializers::extra::SerCheck;
13-
use crate::serializers::filter::SchemaFilter;
1413
use crate::serializers::infer::{infer_serialize, infer_to_python};
1514
use crate::serializers::ob_type::ObType;
16-
use crate::serializers::type_serializers::typed_dict::{TypedDictField, TypedDictSerializer};
15+
use crate::serializers::type_serializers::typed_dict::{FieldSerializer, TypedDictSerializer};
1716

1817
use super::{
1918
infer_json_key, infer_json_key_known, object_to_dict, py_err_se_err, BuildSerializer, CombinedSerializer, Extra,
@@ -38,8 +37,7 @@ impl BuildSerializer for ModelFieldsBuilder {
3837
);
3938

4039
let fields_dict: &PyDict = schema.get_as_req(intern!(py, "fields"))?;
41-
let mut fields: AHashMap<String, TypedDictField> = AHashMap::with_capacity(fields_dict.len());
42-
let mut exclude: Vec<Py<PyString>> = Vec::with_capacity(fields_dict.len());
40+
let mut fields: AHashMap<String, FieldSerializer> = AHashMap::with_capacity(fields_dict.len());
4341

4442
for (key, value) in fields_dict.iter() {
4543
let key_py: &PyString = key.downcast()?;
@@ -49,22 +47,21 @@ impl BuildSerializer for ModelFieldsBuilder {
4947
let key_py: Py<PyString> = key_py.into_py(py);
5048

5149
if field_info.get_as(intern!(py, "serialization_exclude"))? == Some(true) {
52-
exclude.push(key_py.clone_ref(py));
50+
fields.insert(key, FieldSerializer::new(py, key_py, None, None, true));
5351
} else {
5452
let alias: Option<String> = field_info.get_as(intern!(py, "serialization_alias"))?;
5553

5654
let schema = field_info.get_as_req(intern!(py, "schema"))?;
5755
let serializer = CombinedSerializer::build(schema, config, definitions)
5856
.map_err(|e| py_error_type!("Field `{}`:\n {}", key, e))?;
5957

60-
fields.insert(key, TypedDictField::new(py, key_py, alias, serializer, true));
58+
fields.insert(key, FieldSerializer::new(py, key_py, alias, Some(serializer), true));
6159
}
6260
}
6361

64-
let filter = SchemaFilter::from_vec_hash(py, exclude)?;
6562
let computed_fields = ComputedFields::new(schema)?;
6663

67-
Ok(TypedDictSerializer::new(fields, include_extra, filter, computed_fields).into())
64+
Ok(TypedDictSerializer::new(fields, include_extra, computed_fields).into())
6865
}
6966
}
7067

src/serializers/type_serializers/typed_dict.rs

Lines changed: 33 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,23 @@ use super::{
1616
ComputedFields, Extra, PydanticSerializer, SchemaFilter, SerializeInfer, TypeSerializer,
1717
};
1818

19+
/// representation of a field for serialization, used by `TypedDictSerializer` and `ModelFieldsSerializer`
1920
#[derive(Debug, Clone)]
20-
pub(super) struct TypedDictField {
21+
pub(super) struct FieldSerializer {
2122
key_py: Py<PyString>,
2223
alias: Option<String>,
2324
alias_py: Option<Py<PyString>>,
24-
serializer: CombinedSerializer,
25+
// None serializer means exclude
26+
serializer: Option<CombinedSerializer>,
2527
required: bool,
2628
}
2729

28-
impl TypedDictField {
30+
impl FieldSerializer {
2931
pub(super) fn new(
3032
py: Python,
3133
key_py: Py<PyString>,
3234
alias: Option<String>,
33-
serializer: CombinedSerializer,
35+
serializer: Option<CombinedSerializer>,
3436
required: bool,
3537
) -> Self {
3638
let alias_py = alias.as_ref().map(|alias| PyString::new(py, alias.as_str()).into());
@@ -64,10 +66,10 @@ impl TypedDictField {
6466

6567
#[derive(Debug, Clone)]
6668
pub struct TypedDictSerializer {
67-
fields: AHashMap<String, TypedDictField>,
69+
fields: AHashMap<String, FieldSerializer>,
6870
computed_fields: Option<ComputedFields>,
6971
include_extra: bool,
70-
// isize because we look up include exclude via `.hash()` which returns an isize
72+
// isize because we look up filter via `.hash()` which returns an isize
7173
filter: SchemaFilter<isize>,
7274
}
7375

@@ -90,56 +92,44 @@ impl BuildSerializer for TypedDictSerializer {
9092
);
9193

9294
let fields_dict: &PyDict = schema.get_as_req(intern!(py, "fields"))?;
93-
let mut fields: AHashMap<String, TypedDictField> = AHashMap::with_capacity(fields_dict.len());
94-
let mut exclude: Vec<Py<PyString>> = Vec::with_capacity(fields_dict.len());
95+
let mut fields: AHashMap<String, FieldSerializer> = AHashMap::with_capacity(fields_dict.len());
9596

9697
for (key, value) in fields_dict.iter() {
9798
let key_py: &PyString = key.downcast()?;
9899
let key: String = key_py.extract()?;
99100
let field_info: &PyDict = value.downcast()?;
100101

101102
let key_py: Py<PyString> = key_py.into_py(py);
103+
let required = field_info.get_as(intern!(py, "required"))?.unwrap_or(total);
102104

103105
if field_info.get_as(intern!(py, "serialization_exclude"))? == Some(true) {
104-
exclude.push(key_py.clone_ref(py));
106+
fields.insert(key, FieldSerializer::new(py, key_py, None, None, required));
105107
} else {
106108
let alias: Option<String> = field_info.get_as(intern!(py, "serialization_alias"))?;
107109

108110
let schema = field_info.get_as_req(intern!(py, "schema"))?;
109111
let serializer = CombinedSerializer::build(schema, config, definitions)
110112
.map_err(|e| py_error_type!("Field `{}`:\n {}", key, e))?;
111-
112-
fields.insert(
113-
key,
114-
TypedDictField::new(
115-
py,
116-
key_py,
117-
alias,
118-
serializer,
119-
field_info.get_as(intern!(py, "required"))?.unwrap_or(total),
120-
),
121-
);
113+
fields.insert(key, FieldSerializer::new(py, key_py, alias, Some(serializer), required));
122114
}
123115
}
124116

125-
let filter = SchemaFilter::from_vec_hash(py, exclude)?;
126117
let computed_fields = ComputedFields::new(schema)?;
127118

128-
Ok(Self::new(fields, include_extra, filter, computed_fields).into())
119+
Ok(Self::new(fields, include_extra, computed_fields).into())
129120
}
130121
}
131122

132123
impl TypedDictSerializer {
133124
pub(super) fn new(
134-
fields: AHashMap<String, TypedDictField>,
125+
fields: AHashMap<String, FieldSerializer>,
135126
include_extra: bool,
136-
filter: SchemaFilter<isize>,
137127
computed_fields: Option<ComputedFields>,
138128
) -> Self {
139129
Self {
140130
fields,
141131
include_extra,
142-
filter,
132+
filter: SchemaFilter::default(),
143133
computed_fields,
144134
}
145135
}
@@ -151,9 +141,9 @@ impl TypedDictSerializer {
151141
}
152142
}
153143

154-
fn exclude_default(&self, value: &PyAny, extra: &Extra, field: &TypedDictField) -> PyResult<bool> {
144+
fn exclude_default(&self, value: &PyAny, extra: &Extra, serializer: &CombinedSerializer) -> PyResult<bool> {
155145
if extra.exclude_defaults {
156-
if let Some(default) = field.serializer.get_default(value.py())? {
146+
if let Some(default) = serializer.get_default(value.py())? {
157147
if value.eq(default)? {
158148
return Ok(true);
159149
}
@@ -190,21 +180,25 @@ impl TypeSerializer for TypedDictSerializer {
190180
};
191181

192182
for (key, value) in py_dict {
193-
let extra = Extra {
194-
field_name: Some(key.extract()?),
195-
..td_extra
196-
};
197183
if extra.exclude_none && value.is_none() {
198184
continue;
199185
}
200186
if let Some((next_include, next_exclude)) = self.filter.key_filter(key, include, exclude)? {
187+
let extra = Extra {
188+
field_name: Some(key.extract()?),
189+
..td_extra
190+
};
201191
if let Ok(key_py_str) = key.downcast::<PyString>() {
202192
let key_str = key_py_str.to_str()?;
203193
if let Some(field) = self.fields.get(key_str) {
204-
if self.exclude_default(value, &extra, field)? {
194+
let serializer = match field.serializer {
195+
Some(ref serializer) => serializer,
196+
None => continue,
197+
};
198+
if self.exclude_default(value, &extra, serializer)? {
205199
continue;
206200
}
207-
let value = field.serializer.to_python(value, next_include, next_exclude, &extra)?;
201+
let value = serializer.to_python(value, next_include, next_exclude, &extra)?;
208202
let output_key = field.get_key_py(py, &extra);
209203
output_dict.set_item(output_key, value)?;
210204

@@ -289,17 +283,15 @@ impl TypeSerializer for TypedDictSerializer {
289283
if let Ok(key_py_str) = key.downcast::<PyString>() {
290284
let key_str = key_py_str.to_str().map_err(py_err_se_err)?;
291285
if let Some(field) = self.fields.get(key_str) {
292-
if self.exclude_default(value, &extra, field).map_err(py_err_se_err)? {
286+
let serializer = match field.serializer {
287+
Some(ref serializer) => serializer,
288+
None => continue,
289+
};
290+
if self.exclude_default(value, &extra, serializer).map_err(py_err_se_err)? {
293291
continue;
294292
}
295293
let output_key = field.get_key_json(key_str, &extra);
296-
let s = PydanticSerializer::new(
297-
value,
298-
&field.serializer,
299-
next_include,
300-
next_exclude,
301-
&extra,
302-
);
294+
let s = PydanticSerializer::new(value, serializer, next_include, next_exclude, &extra);
303295
map.serialize_entry(&output_key, &s)?;
304296
continue;
305297
}

0 commit comments

Comments
 (0)