Skip to content

Commit 3e1fb94

Browse files
committed
improve field handling
1 parent ad7974e commit 3e1fb94

File tree

8 files changed

+156
-100
lines changed

8 files changed

+156
-100
lines changed

pydantic_core/core_schema.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2841,6 +2841,7 @@ class ModelSchema(TypedDict, total=False):
28412841
revalidate_instances: Literal['always', 'never', 'subclass-instances'] # default: 'never'
28422842
strict: bool
28432843
frozen: bool
2844+
extra_behavior: ExtraBehavior
28442845
config: CoreConfig
28452846
ref: str
28462847
metadata: Any
@@ -2855,6 +2856,7 @@ def model_schema(
28552856
revalidate_instances: Literal['always', 'never', 'subclass-instances'] | None = None,
28562857
strict: bool | None = None,
28572858
frozen: bool | None = None,
2859+
extra_behavior: ExtraBehavior | None = None,
28582860
config: CoreConfig | None = None,
28592861
ref: str | None = None,
28602862
metadata: Any = None,
@@ -2894,6 +2896,7 @@ class MyModel:
28942896
should re-validate defaults to config.revalidate_instances, else 'never'
28952897
strict: Whether the model is strict
28962898
frozen: Whether the model is frozen
2899+
extra_behavior: The extra behavior to use for the model, used in serialization
28972900
config: The config to use for the model
28982901
ref: optional unique identifier of the schema, used to reference the schema in other places
28992902
metadata: Any other information you want to include with the schema, not used by pydantic-core
@@ -2907,6 +2910,7 @@ class MyModel:
29072910
revalidate_instances=revalidate_instances,
29082911
strict=strict,
29092912
frozen=frozen,
2913+
extra_behavior=extra_behavior,
29102914
config=config,
29112915
ref=ref,
29122916
metadata=metadata,

src/serializers/fields.rs

Lines changed: 76 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -62,24 +62,6 @@ impl SerField {
6262
}
6363
Cow::Borrowed(key_str)
6464
}
65-
66-
pub fn to_python(
67-
&self,
68-
output_dict: &PyDict,
69-
value: &PyAny,
70-
next_include: Option<&PyAny>,
71-
next_exclude: Option<&PyAny>,
72-
extra: &Extra,
73-
) -> PyResult<()> {
74-
if let Some(ref serializer) = self.serializer {
75-
if !exclude_default(value, extra, serializer)? {
76-
let value = serializer.to_python(value, next_include, next_exclude, extra)?;
77-
let output_key = self.get_key_py(output_dict.py(), extra);
78-
output_dict.set_item(output_key, value)?;
79-
}
80-
}
81-
Ok(())
82-
}
8365
}
8466

8567
fn exclude_default(value: &PyAny, extra: &Extra, serializer: &CombinedSerializer) -> PyResult<bool> {
@@ -93,12 +75,22 @@ fn exclude_default(value: &PyAny, extra: &Extra, serializer: &CombinedSerializer
9375
Ok(false)
9476
}
9577

78+
#[derive(Debug, Clone)]
79+
pub(super) enum FieldsMode {
80+
// typeddict with no extra items
81+
SimpleDict,
82+
// a model - get `__dict__` and `__pydantic_extra__` - `GeneralFieldsSerializer` will get a tuple
83+
ModelExtra,
84+
// typeddict with extra items - one dict with extra items
85+
TypedDictAllow,
86+
}
87+
9688
/// General purpose serializer for fields - used by dataclasses, models and typed_dicts
9789
#[derive(Debug, Clone)]
9890
pub struct GeneralFieldsSerializer {
9991
fields: AHashMap<String, SerField>,
10092
computed_fields: Option<ComputedFields>,
101-
include_extra: bool,
93+
mode: FieldsMode,
10294
// isize because we look up filter via `.hash()` which returns an isize
10395
filter: SchemaFilter<isize>,
10496
required_fields: usize,
@@ -107,26 +99,39 @@ pub struct GeneralFieldsSerializer {
10799
impl GeneralFieldsSerializer {
108100
pub(super) fn new(
109101
fields: AHashMap<String, SerField>,
110-
include_extra: bool,
102+
mode: FieldsMode,
111103
computed_fields: Option<ComputedFields>,
112104
) -> Self {
113105
let required_fields = fields.values().filter(|f| f.required).count();
114106
Self {
115107
fields,
116-
include_extra,
108+
mode,
117109
filter: SchemaFilter::default(),
118110
computed_fields,
119111
required_fields,
120112
}
121113
}
122114

123115
fn extract_dicts<'a>(&self, value: &'a PyAny) -> Option<(&'a PyDict, Option<&'a PyDict>)> {
124-
if let Ok(main_dict) = value.downcast::<PyDict>() {
125-
Some((main_dict, None))
126-
} else if let Ok((main_dict, extra_dict)) = value.extract::<(&PyDict, &PyDict)>() {
127-
Some((main_dict, Some(extra_dict)))
128-
} else {
129-
None
116+
match self.mode {
117+
FieldsMode::ModelExtra => {
118+
if let Ok((main_dict, extra_dict)) = value.extract::<(&PyDict, &PyAny)>() {
119+
if let Ok(extra_dict) = extra_dict.downcast::<PyDict>() {
120+
Some((main_dict, Some(extra_dict)))
121+
} else {
122+
Some((main_dict, None))
123+
}
124+
} else {
125+
None
126+
}
127+
}
128+
_ => {
129+
if let Ok(main_dict) = value.downcast::<PyDict>() {
130+
Some((main_dict, None))
131+
} else {
132+
None
133+
}
134+
}
130135
}
131136
}
132137
}
@@ -163,32 +168,39 @@ impl TypeSerializer for GeneralFieldsSerializer {
163168
return infer_to_python(value, include, exclude, &td_extra);
164169
};
165170

166-
// NOTE! we maintain the order of the input dict assuming that's right
167171
let output_dict = PyDict::new(py);
168172
let mut used_req_fields: usize = 0;
169173

174+
// NOTE! we maintain the order of the input dict assuming that's right
170175
for (key, value) in main_dict {
176+
let key_str = key_str(key)?;
177+
let op_field = self.fields.get(key_str);
171178
if extra.exclude_none && value.is_none() {
179+
if let Some(field) = op_field {
180+
if field.required {
181+
used_req_fields += 1;
182+
}
183+
}
172184
continue;
173185
}
186+
let extra = Extra {
187+
field_name: Some(key_str),
188+
..td_extra
189+
};
174190
if let Some((next_include, next_exclude)) = self.filter.key_filter(key, include, exclude)? {
175-
let extra = Extra {
176-
field_name: Some(key.extract()?),
177-
..td_extra
178-
};
179-
if let Ok(key_py_str) = key.downcast::<PyString>() {
180-
let key_str = key_py_str.to_str()?;
181-
if let Some(field) = self.fields.get(key_str) {
182-
field.to_python(output_dict, value, next_include, next_exclude, &extra)?;
183-
184-
if field.required {
185-
used_req_fields += 1;
191+
if let Some(field) = op_field {
192+
if let Some(ref serializer) = field.serializer {
193+
if !exclude_default(value, &extra, serializer)? {
194+
let value = serializer.to_python(value, next_include, next_exclude, &extra)?;
195+
let output_key = field.get_key_py(output_dict.py(), &extra);
196+
output_dict.set_item(output_key, value)?;
186197
}
187-
continue;
188198
}
189-
}
190-
if self.include_extra {
191-
// TODO test this
199+
200+
if field.required {
201+
used_req_fields += 1;
202+
}
203+
} else if matches!(self.mode, FieldsMode::TypedDictAllow) {
192204
let value = infer_to_python(value, next_include, next_exclude, &extra)?;
193205
output_dict.set_item(key, value)?;
194206
} else if extra.check.enabled() {
@@ -242,40 +254,37 @@ impl TypeSerializer for GeneralFieldsSerializer {
242254
model: extra.model.map_or_else(|| Some(value), Some),
243255
..*extra
244256
};
245-
let expected_len = match self.include_extra {
246-
true => main_dict.len() + option_length!(self.computed_fields),
247-
false => self.fields.len() + option_length!(extra_dict) + option_length!(self.computed_fields),
257+
let expected_len = match self.mode {
258+
FieldsMode::TypedDictAllow => main_dict.len() + option_length!(self.computed_fields),
259+
_ => self.fields.len() + option_length!(extra_dict) + option_length!(self.computed_fields),
248260
};
249261
// NOTE! As above, we maintain the order of the input dict assuming that's right
250262
// we don't bother with `used_fields` here because on unions, `to_python(..., mode='json')` is used
251263
let mut map = serializer.serialize_map(Some(expected_len))?;
252264

253265
for (key, value) in main_dict {
254-
let extra = Extra {
255-
field_name: Some(key.extract().map_err(py_err_se_err)?),
256-
..td_extra
257-
};
258266
if extra.exclude_none && value.is_none() {
259267
continue;
260268
}
269+
let key_str = key_str(key).map_err(py_err_se_err)?;
270+
let extra = Extra {
271+
field_name: Some(key_str),
272+
..td_extra
273+
};
274+
261275
let filter = self.filter.key_filter(key, include, exclude).map_err(py_err_se_err)?;
262276
if let Some((next_include, next_exclude)) = filter {
263-
if let Ok(key_py_str) = key.downcast::<PyString>() {
264-
let key_str = key_py_str.to_str().map_err(py_err_se_err)?;
265-
if let Some(field) = self.fields.get(key_str) {
266-
if let Some(ref serializer) = field.serializer {
267-
if !exclude_default(value, &extra, serializer).map_err(py_err_se_err)? {
268-
let s = PydanticSerializer::new(value, serializer, next_include, next_exclude, &extra);
269-
let output_key = field.get_key_json(key_str, &extra);
270-
map.serialize_entry(&output_key, &s)?;
271-
}
272-
continue;
277+
if let Some(field) = self.fields.get(key_str) {
278+
if let Some(ref serializer) = field.serializer {
279+
if !exclude_default(value, &extra, serializer).map_err(py_err_se_err)? {
280+
let s = PydanticSerializer::new(value, serializer, next_include, next_exclude, &extra);
281+
let output_key = field.get_key_json(key_str, &extra);
282+
map.serialize_entry(&output_key, &s)?;
273283
}
274284
}
275-
}
276-
if self.include_extra {
277-
let s = SerializeInfer::new(value, include, exclude, &extra);
285+
} else if matches!(self.mode, FieldsMode::TypedDictAllow) {
278286
let output_key = infer_json_key(key, &extra).map_err(py_err_se_err)?;
287+
let s = SerializeInfer::new(value, next_include, next_exclude, &extra);
279288
map.serialize_entry(&output_key, &s)?
280289
}
281290
}
@@ -303,3 +312,7 @@ impl TypeSerializer for GeneralFieldsSerializer {
303312
"fields"
304313
}
305314
}
315+
316+
fn key_str(key: &PyAny) -> PyResult<&str> {
317+
key.downcast::<PyString>()?.to_str()
318+
}

src/serializers/type_serializers/dataclass.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ use pyo3::types::{PyDict, PyList, PyString};
44

55
use ahash::AHashMap;
66

7-
use crate::build_tools::{py_error_type, SchemaDict};
7+
use crate::build_tools::{py_error_type, ExtraBehavior, SchemaDict};
88
use crate::definitions::DefinitionsBuilder;
99

1010
use super::model::ModelSerializer;
11-
use super::{BuildSerializer, CombinedSerializer, ComputedFields, GeneralFieldsSerializer, SerField};
11+
use super::{BuildSerializer, CombinedSerializer, ComputedFields, FieldsMode, GeneralFieldsSerializer, SerField};
1212

1313
pub struct DataclassArgsBuilder;
1414

@@ -25,6 +25,11 @@ impl BuildSerializer for DataclassArgsBuilder {
2525
let fields_list: &PyList = schema.get_as_req(intern!(py, "fields"))?;
2626
let mut fields: AHashMap<String, SerField> = AHashMap::with_capacity(fields_list.len());
2727

28+
let fields_mode = match ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Ignore)? {
29+
ExtraBehavior::Allow => FieldsMode::TypedDictAllow,
30+
_ => FieldsMode::SimpleDict,
31+
};
32+
2833
for (index, item) in fields_list.iter().enumerate() {
2934
let field_info: &PyDict = item.downcast()?;
3035
let name: String = field_info.get_as_req(intern!(py, "name"))?;
@@ -45,7 +50,7 @@ impl BuildSerializer for DataclassArgsBuilder {
4550

4651
let computed_fields = ComputedFields::new(schema)?;
4752

48-
Ok(GeneralFieldsSerializer::new(fields, false, computed_fields).into())
53+
Ok(GeneralFieldsSerializer::new(fields, fields_mode, computed_fields).into())
4954
}
5055
}
5156

src/serializers/type_serializers/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ pub(self) use super::computed_fields::ComputedFields;
2727
pub(self) use super::config::utf8_py_error;
2828
pub(self) use super::errors::{py_err_se_err, PydanticSerializationError};
2929
pub(self) use super::extra::{Extra, ExtraOwned, SerCheck, SerMode};
30-
pub(self) use super::fields::{GeneralFieldsSerializer, SerField};
30+
pub(self) use super::fields::{FieldsMode, GeneralFieldsSerializer, SerField};
3131
pub(self) use super::filter::{AnyFilter, SchemaFilter};
3232
pub(self) use super::infer::{
3333
infer_json_key, infer_json_key_known, infer_serialize, infer_serialize_known, infer_to_python,

src/serializers/type_serializers/model.rs

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ use crate::definitions::DefinitionsBuilder;
1111

1212
use super::{
1313
infer_json_key, infer_json_key_known, infer_serialize, infer_to_python, object_to_dict, py_err_se_err,
14-
BuildSerializer, CombinedSerializer, ComputedFields, Extra, GeneralFieldsSerializer, ObType, SerCheck, SerField,
15-
TypeSerializer,
14+
BuildSerializer, CombinedSerializer, ComputedFields, Extra, FieldsMode, GeneralFieldsSerializer, ObType, SerCheck,
15+
SerField, TypeSerializer,
1616
};
1717

1818
pub struct ModelFieldsBuilder;
@@ -27,10 +27,10 @@ impl BuildSerializer for ModelFieldsBuilder {
2727
) -> PyResult<CombinedSerializer> {
2828
let py = schema.py();
2929

30-
let include_extra = matches!(
31-
ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Ignore)?,
32-
ExtraBehavior::Allow
33-
);
30+
let fields_mode = match has_extra(schema, config)? {
31+
true => FieldsMode::ModelExtra,
32+
false => FieldsMode::SimpleDict,
33+
};
3434

3535
let fields_dict: &PyDict = schema.get_as_req(intern!(py, "fields"))?;
3636
let mut fields: AHashMap<String, SerField> = AHashMap::with_capacity(fields_dict.len());
@@ -57,14 +57,15 @@ impl BuildSerializer for ModelFieldsBuilder {
5757

5858
let computed_fields = ComputedFields::new(schema)?;
5959

60-
Ok(GeneralFieldsSerializer::new(fields, include_extra, computed_fields).into())
60+
Ok(GeneralFieldsSerializer::new(fields, fields_mode, computed_fields).into())
6161
}
6262
}
6363

6464
#[derive(Debug, Clone)]
6565
pub struct ModelSerializer {
6666
class: Py<PyType>,
6767
serializer: Box<CombinedSerializer>,
68+
has_extra: bool,
6869
name: String,
6970
}
7071

@@ -84,12 +85,19 @@ impl BuildSerializer for ModelSerializer {
8485
Ok(Self {
8586
class: class.into(),
8687
serializer,
88+
has_extra: has_extra(schema, config)?,
8789
name: class.getattr(intern!(py, "__name__"))?.extract()?,
8890
}
8991
.into())
9092
}
9193
}
9294

95+
fn has_extra(schema: &PyDict, config: Option<&PyDict>) -> PyResult<bool> {
96+
let py = schema.py();
97+
let extra_behaviour = ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Ignore)?;
98+
Ok(matches!(extra_behaviour, ExtraBehavior::Allow))
99+
}
100+
93101
impl ModelSerializer {
94102
fn allow_value(&self, value: &PyAny, extra: &Extra) -> PyResult<bool> {
95103
match extra.check {
@@ -103,16 +111,12 @@ impl ModelSerializer {
103111
let py = value.py();
104112
let dict = object_to_dict(value, true, extra)?;
105113

106-
let model_extra = match value.getattr(intern!(py, "__pydantic_extra__")) {
107-
Ok(model_extra) => model_extra,
108-
Err(_) => return Ok(dict),
109-
};
110-
111-
if model_extra.is_none() {
112-
Ok(dict)
113-
} else {
114+
if self.has_extra {
115+
let model_extra = value.getattr(intern!(py, "__pydantic_extra__"))?;
114116
let py_tuple = (dict, model_extra).to_object(py);
115117
Ok(py_tuple.into_ref(py))
118+
} else {
119+
Ok(dict)
116120
}
117121
}
118122
}

0 commit comments

Comments
 (0)