Skip to content

Commit a0826b1

Browse files
committed
fix dataclass validation & serialization
1 parent 66072e0 commit a0826b1

File tree

12 files changed

+344
-130
lines changed

12 files changed

+344
-130
lines changed

pydantic_core/core_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3123,7 +3123,7 @@ def dataclass_schema(
31233123
another schema, not as the root type.
31243124
31253125
Args:
3126-
cls: The dataclass type, used to to perform subclass checks
3126+
cls: The dataclass type, used to perform subclass checks
31273127
schema: The schema to use for the dataclass fields
31283128
cls_name: The name to use in error locs, etc; this is useful for generics (default: `cls.__name__`)
31293129
post_init: Whether to call `__post_init__` after validation

src/serializers/infer.rs

Lines changed: 46 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use super::errors::{py_err_se_err, PydanticSerializationError};
2121
use super::extra::{Extra, SerMode};
2222
use super::filter::AnyFilter;
2323
use super::ob_type::ObType;
24-
use super::shared::object_to_dict;
24+
use super::shared::dataclass_to_dict;
2525

2626
pub(crate) fn infer_to_python(
2727
value: &PyAny,
@@ -97,29 +97,23 @@ pub(crate) fn infer_to_python_known(
9797
Ok::<PyObject, PyErr>(new_dict.into_py(py))
9898
};
9999

100-
let serialize_with_serializer = |value: &PyAny, is_model: bool| {
101-
if let Ok(py_serializer) = value.getattr(intern!(py, "__pydantic_serializer__")) {
102-
if let Ok(serializer) = py_serializer.extract::<SchemaSerializer>() {
103-
let extra = serializer.build_extra(
104-
py,
105-
extra.mode,
106-
extra.by_alias,
107-
extra.warnings,
108-
extra.exclude_unset,
109-
extra.exclude_defaults,
110-
extra.exclude_none,
111-
extra.round_trip,
112-
extra.rec_guard,
113-
extra.serialize_unknown,
114-
extra.fallback,
115-
);
116-
return serializer.serializer.to_python(value, include, exclude, &extra);
117-
}
118-
}
119-
// Fallback to dict serialization if `__pydantic_serializer__` is not set.
120-
// This currently only affects non-pydantic dataclasses.
121-
let dict = object_to_dict(value, is_model, extra)?;
122-
serialize_dict(dict)
100+
let serialize_with_serializer = || {
101+
let py_serializer = value.getattr(intern!(py, "__pydantic_serializer__"))?;
102+
let serializer: SchemaSerializer = py_serializer.extract()?;
103+
let extra = serializer.build_extra(
104+
py,
105+
extra.mode,
106+
extra.by_alias,
107+
extra.warnings,
108+
extra.exclude_unset,
109+
extra.exclude_defaults,
110+
extra.exclude_none,
111+
extra.round_trip,
112+
extra.rec_guard,
113+
extra.serialize_unknown,
114+
extra.fallback,
115+
);
116+
serializer.serializer.to_python(value, include, exclude, &extra)
123117
};
124118

125119
let value = match extra.mode {
@@ -191,8 +185,8 @@ pub(crate) fn infer_to_python_known(
191185
let py_url: PyMultiHostUrl = value.extract()?;
192186
py_url.__str__().into_py(py)
193187
}
194-
ObType::PydanticSerializable => serialize_with_serializer(value, true)?,
195-
ObType::Dataclass => serialize_with_serializer(value, false)?,
188+
ObType::PydanticSerializable => serialize_with_serializer()?,
189+
ObType::Dataclass => serialize_dict(dataclass_to_dict(value)?)?,
196190
ObType::Enum => {
197191
let v = value.getattr(intern!(py, "value"))?;
198192
infer_to_python(v, include, exclude, extra)?.into_py(py)
@@ -257,8 +251,8 @@ pub(crate) fn infer_to_python_known(
257251
}
258252
new_dict.into_py(py)
259253
}
260-
ObType::PydanticSerializable => serialize_with_serializer(value, true)?,
261-
ObType::Dataclass => serialize_with_serializer(value, false)?,
254+
ObType::PydanticSerializable => serialize_with_serializer()?,
255+
ObType::Dataclass => serialize_dict(dataclass_to_dict(value)?)?,
262256
ObType::Generator => {
263257
let iter = super::type_serializers::generator::SerializationIterator::new(
264258
value.downcast()?,
@@ -406,36 +400,6 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
406400
}};
407401
}
408402

409-
macro_rules! serialize_with_serializer {
410-
($py_serializable:expr, $is_model:expr) => {{
411-
let py = $py_serializable.py();
412-
if let Ok(py_serializer) = value.getattr(intern!(py, "__pydantic_serializer__")) {
413-
if let Ok(extracted_serializer) = py_serializer.extract::<SchemaSerializer>() {
414-
let extra = extracted_serializer.build_extra(
415-
py,
416-
extra.mode,
417-
extra.by_alias,
418-
extra.warnings,
419-
extra.exclude_unset,
420-
extra.exclude_defaults,
421-
extra.exclude_none,
422-
extra.round_trip,
423-
extra.rec_guard,
424-
extra.serialize_unknown,
425-
extra.fallback,
426-
);
427-
let pydantic_serializer =
428-
PydanticSerializer::new(value, &extracted_serializer.serializer, include, exclude, &extra);
429-
return pydantic_serializer.serialize(serializer);
430-
}
431-
}
432-
// Fallback to dict serialization if `__pydantic_serializer__` is not set.
433-
// This currently only affects non-pydantic dataclasses.
434-
let dict = object_to_dict(value, $is_model, extra).map_err(py_err_se_err)?;
435-
serialize_dict!(dict)
436-
}};
437-
}
438-
439403
let ser_result = match ob_type {
440404
ObType::None => serializer.serialize_none(),
441405
ObType::Int | ObType::IntSubclass => serialize!(i64),
@@ -490,8 +454,30 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
490454
let py_url: PyMultiHostUrl = value.extract().map_err(py_err_se_err)?;
491455
serializer.serialize_str(&py_url.__str__())
492456
}
493-
ObType::Dataclass => serialize_with_serializer!(value, false),
494-
ObType::PydanticSerializable => serialize_with_serializer!(value, true),
457+
ObType::PydanticSerializable => {
458+
let py = value.py();
459+
let py_serializer = value
460+
.getattr(intern!(py, "__pydantic_serializer__"))
461+
.map_err(py_err_se_err)?;
462+
let extracted_serializer: SchemaSerializer = py_serializer.extract().map_err(py_err_se_err)?;
463+
let extra = extracted_serializer.build_extra(
464+
py,
465+
extra.mode,
466+
extra.by_alias,
467+
extra.warnings,
468+
extra.exclude_unset,
469+
extra.exclude_defaults,
470+
extra.exclude_none,
471+
extra.round_trip,
472+
extra.rec_guard,
473+
extra.serialize_unknown,
474+
extra.fallback,
475+
);
476+
let pydantic_serializer =
477+
PydanticSerializer::new(value, &extracted_serializer.serializer, include, exclude, &extra);
478+
pydantic_serializer.serialize(serializer)
479+
}
480+
ObType::Dataclass => serialize_dict!(dataclass_to_dict(value).map_err(py_err_se_err)?),
495481
ObType::Enum => {
496482
let v = value.getattr(intern!(value.py(), "value")).map_err(py_err_se_err)?;
497483
infer_serialize(v, serializer, include, exclude, extra)

src/serializers/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use config::SerializationConfig;
1111
pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue};
1212
use extra::{CollectWarnings, SerRecursionGuard};
1313
pub(crate) use extra::{Extra, SerMode, SerializationState};
14-
pub(crate) use shared::slots_dc_dict;
14+
pub(crate) use shared::dataclass_to_dict;
1515
pub use shared::CombinedSerializer;
1616
use shared::{to_json_bytes, BuildSerializer, TypeSerializer};
1717

src/serializers/shared.rs

Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use std::fmt::Debug;
44
use pyo3::exceptions::PyTypeError;
55
use pyo3::once_cell::GILOnceCell;
66
use pyo3::prelude::*;
7-
use pyo3::types::{PyDict, PySet, PyString};
7+
use pyo3::types::{PyDict, PyString};
88
use pyo3::{intern, PyTraverseError, PyVisit};
99

1010
use enum_dispatch::enum_dispatch;
@@ -96,7 +96,6 @@ combined_serializer! {
9696
super::type_serializers::other::CallableBuilder;
9797
super::type_serializers::definitions::DefinitionsSerializerBuilder;
9898
super::type_serializers::dataclass::DataclassArgsBuilder;
99-
super::type_serializers::dataclass::DataclassBuilder;
10099
super::type_serializers::function::FunctionBeforeSerializerBuilder;
101100
super::type_serializers::function::FunctionAfterSerializerBuilder;
102101
super::type_serializers::function::FunctionPlainSerializerBuilder;
@@ -124,6 +123,7 @@ combined_serializer! {
124123
Generator: super::type_serializers::generator::GeneratorSerializer;
125124
Dict: super::type_serializers::dict::DictSerializer;
126125
Model: super::type_serializers::model::ModelSerializer;
126+
Dataclass: super::type_serializers::dataclass::DataclassSerializer;
127127
Url: super::type_serializers::url::UrlSerializer;
128128
MultiHostUrl: super::type_serializers::url::MultiHostUrlSerializer;
129129
Any: super::type_serializers::any::AnySerializer;
@@ -328,42 +328,23 @@ pub(crate) fn to_json_bytes(
328328
Ok(bytes)
329329
}
330330

331-
pub(super) fn object_to_dict<'py>(value: &'py PyAny, is_model: bool, extra: &Extra) -> PyResult<&'py PyDict> {
332-
let py = value.py();
333-
let attrs: &PyDict = match value.getattr(intern!(py, "__dict__")) {
334-
Ok(attr) => attr.downcast()?,
335-
Err(_) => return slots_dc_dict(value),
336-
};
337-
338-
if is_model && extra.exclude_unset {
339-
let fields_set: &PySet = value.getattr(intern!(py, "__pydantic_fields_set__"))?.downcast()?;
340-
341-
let new_attrs = attrs.copy()?;
342-
for key in new_attrs.keys() {
343-
if !fields_set.contains(key)? {
344-
new_attrs.del_item(key)?;
345-
}
346-
}
347-
Ok(new_attrs)
348-
} else {
349-
Ok(attrs)
350-
}
351-
}
352-
353331
static DC_FIELD_MARKER: GILOnceCell<PyObject> = GILOnceCell::new();
354332

355-
pub(crate) fn slots_dc_dict(dc: &PyAny) -> PyResult<&PyDict> {
356-
let py = dc.py();
357-
let dc_fields: &PyDict = dc.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?;
358-
let dict = PyDict::new(py);
359-
360-
// need to match the logic from dataclasses.fields `tuple(f for f in fields.values() if f._field_type is _FIELD)`
333+
/// needed to match the logic from dataclasses.fields `tuple(f for f in fields.values() if f._field_type is _FIELD)`
334+
pub(super) fn get_field_marker(py: Python<'_>) -> PyResult<&PyAny> {
361335
let field_type_marker_obj = DC_FIELD_MARKER.get_or_try_init(py, || {
362336
let field_ = py.import("dataclasses")?.getattr("_FIELD")?;
363337
Ok::<PyObject, PyErr>(field_.into_py(py))
364338
})?;
365-
let field_type_marker = field_type_marker_obj.as_ref(py);
339+
Ok(field_type_marker_obj.as_ref(py))
340+
}
341+
342+
pub(crate) fn dataclass_to_dict(dc: &PyAny) -> PyResult<&PyDict> {
343+
let py = dc.py();
344+
let dc_fields: &PyDict = dc.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?;
345+
let dict = PyDict::new(py);
366346

347+
let field_type_marker = get_field_marker(py)?;
367348
for (field_name, field) in dc_fields.iter() {
368349
let field_type = field.getattr(intern!(py, "_field_type"))?;
369350
if field_type.is(field_type_marker) {

0 commit comments

Comments
 (0)