Skip to content

Commit 26fa27d

Browse files
hramezaniadriangbsamuelcolvindmontagu
authored
Add slots to dataclass schema (#617)
Co-authored-by: Adrian Garcia Badaracco <[email protected]> Co-authored-by: Samuel Colvin <[email protected]> Co-authored-by: David Montague <[email protected]>
1 parent 058a0b8 commit 26fa27d

File tree

15 files changed

+718
-147
lines changed

15 files changed

+718
-147
lines changed

pydantic_core/core_schema.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3093,6 +3093,7 @@ class DataclassSchema(TypedDict, total=False):
30933093
type: Required[Literal['dataclass']]
30943094
cls: Required[Type[Any]]
30953095
schema: Required[CoreSchema]
3096+
fields: Required[List[str]]
30963097
cls_name: str
30973098
post_init: bool # default: False
30983099
revalidate_instances: Literal['always', 'never', 'subclass-instances'] # default: 'never'
@@ -3101,11 +3102,13 @@ class DataclassSchema(TypedDict, total=False):
31013102
ref: str
31023103
metadata: Any
31033104
serialization: SerSchema
3105+
slots: bool
31043106

31053107

31063108
def dataclass_schema(
31073109
cls: Type[Any],
31083110
schema: CoreSchema,
3111+
fields: List[str],
31093112
*,
31103113
cls_name: str | None = None,
31113114
post_init: bool | None = None,
@@ -3115,14 +3118,17 @@ def dataclass_schema(
31153118
metadata: Any = None,
31163119
serialization: SerSchema | None = None,
31173120
frozen: bool | None = None,
3121+
slots: bool | None = None,
31183122
) -> DataclassSchema:
31193123
"""
31203124
Returns a schema for a dataclass. As with `ModelSchema`, this schema can only be used as a field within
31213125
another schema, not as the root type.
31223126
31233127
Args:
3124-
cls: The dataclass type, used to to perform subclass checks
3128+
cls: The dataclass type, used to perform subclass checks
31253129
schema: The schema to use for the dataclass fields
3130+
fields: Fields of the dataclass, this is used in serialization and in validation during re-validation
3131+
and while validating assignment
31263132
cls_name: The name to use in error locs, etc; this is useful for generics (default: `cls.__name__`)
31273133
post_init: Whether to call `__post_init__` after validation
31283134
revalidate_instances: whether instances of models and dataclasses (including subclass instances)
@@ -3132,10 +3138,13 @@ def dataclass_schema(
31323138
metadata: Any other information you want to include with the schema, not used by pydantic-core
31333139
serialization: Custom serialization schema
31343140
frozen: Whether the dataclass is frozen
3141+
slots: Whether `slots=True` on the dataclass, means each field is assigned independently, rather than
3142+
simply setting `__dict__`, default false
31353143
"""
31363144
return dict_not_none(
31373145
type='dataclass',
31383146
cls=cls,
3147+
fields=fields,
31393148
cls_name=cls_name,
31403149
schema=schema,
31413150
post_init=post_init,
@@ -3145,6 +3154,7 @@ def dataclass_schema(
31453154
metadata=metadata,
31463155
serialization=serialization,
31473156
frozen=frozen,
3157+
slots=slots,
31483158
)
31493159

31503160

src/input/input_abstract.rs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::fmt;
22

3-
use pyo3::types::{PyDict, PyString, PyType};
3+
use pyo3::types::{PyDict, PyType};
44
use pyo3::{intern, prelude::*};
55

66
use crate::errors::{InputValue, LocItem, ValResult};
@@ -40,15 +40,10 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
4040

4141
fn is_none(&self) -> bool;
4242

43-
#[cfg_attr(has_no_coverage, no_coverage)]
44-
fn input_get_attr(&self, _name: &PyString) -> Option<PyResult<&PyAny>> {
43+
fn input_is_instance(&self, _class: &PyType) -> Option<&PyAny> {
4544
None
4645
}
4746

48-
fn is_exact_instance(&self, _class: &PyType) -> bool {
49-
false
50-
}
51-
5247
fn is_python(&self) -> bool {
5348
false
5449
}

src/input/input_python.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,12 @@ impl<'a> Input<'a> for PyAny {
108108
self.is_none()
109109
}
110110

111-
fn input_get_attr(&self, name: &PyString) -> Option<PyResult<&PyAny>> {
112-
Some(self.getattr(name))
113-
}
114-
115-
fn is_exact_instance(&self, class: &PyType) -> bool {
116-
self.get_type().is(class)
111+
fn input_is_instance(&self, class: &PyType) -> Option<&PyAny> {
112+
if self.is_instance(class).unwrap_or(false) {
113+
Some(self)
114+
} else {
115+
None
116+
}
117117
}
118118

119119
fn is_python(&self) -> bool {

src/serializers/infer.rs

Lines changed: 46 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use super::errors::{py_err_se_err, PydanticSerializationError};
2121
use super::extra::{Extra, SerMode};
2222
use super::filter::AnyFilter;
2323
use super::ob_type::ObType;
24-
use super::shared::object_to_dict;
24+
use super::shared::dataclass_to_dict;
2525

2626
pub(crate) fn infer_to_python(
2727
value: &PyAny,
@@ -98,29 +98,23 @@ pub(crate) fn infer_to_python_known(
9898
Ok::<PyObject, PyErr>(new_dict.into_py(py))
9999
};
100100

101-
let serialize_with_serializer = |value: &PyAny, is_model: bool| {
102-
if let Ok(py_serializer) = value.getattr(intern!(py, "__pydantic_serializer__")) {
103-
if let Ok(serializer) = py_serializer.extract::<SchemaSerializer>() {
104-
let extra = serializer.build_extra(
105-
py,
106-
extra.mode,
107-
extra.by_alias,
108-
extra.warnings,
109-
extra.exclude_unset,
110-
extra.exclude_defaults,
111-
extra.exclude_none,
112-
extra.round_trip,
113-
extra.rec_guard,
114-
extra.serialize_unknown,
115-
extra.fallback,
116-
);
117-
return serializer.serializer.to_python(value, include, exclude, &extra);
118-
}
119-
}
120-
// Fallback to dict serialization if `__pydantic_serializer__` is not set.
121-
// This currently only affects non-pydantic dataclasses.
122-
let dict = object_to_dict(value, is_model, extra)?;
123-
serialize_dict(dict)
101+
let serialize_with_serializer = || {
102+
let py_serializer = value.getattr(intern!(py, "__pydantic_serializer__"))?;
103+
let serializer: SchemaSerializer = py_serializer.extract()?;
104+
let extra = serializer.build_extra(
105+
py,
106+
extra.mode,
107+
extra.by_alias,
108+
extra.warnings,
109+
extra.exclude_unset,
110+
extra.exclude_defaults,
111+
extra.exclude_none,
112+
extra.round_trip,
113+
extra.rec_guard,
114+
extra.serialize_unknown,
115+
extra.fallback,
116+
);
117+
serializer.serializer.to_python(value, include, exclude, &extra)
124118
};
125119

126120
let value = match extra.mode {
@@ -192,8 +186,8 @@ pub(crate) fn infer_to_python_known(
192186
let py_url: PyMultiHostUrl = value.extract()?;
193187
py_url.__str__().into_py(py)
194188
}
195-
ObType::PydanticSerializable => serialize_with_serializer(value, true)?,
196-
ObType::Dataclass => serialize_with_serializer(value, false)?,
189+
ObType::PydanticSerializable => serialize_with_serializer()?,
190+
ObType::Dataclass => serialize_dict(dataclass_to_dict(value)?)?,
197191
ObType::Enum => {
198192
let v = value.getattr(intern!(py, "value"))?;
199193
infer_to_python(v, include, exclude, extra)?.into_py(py)
@@ -258,8 +252,8 @@ pub(crate) fn infer_to_python_known(
258252
}
259253
new_dict.into_py(py)
260254
}
261-
ObType::PydanticSerializable => serialize_with_serializer(value, true)?,
262-
ObType::Dataclass => serialize_with_serializer(value, false)?,
255+
ObType::PydanticSerializable => serialize_with_serializer()?,
256+
ObType::Dataclass => serialize_dict(dataclass_to_dict(value)?)?,
263257
ObType::Generator => {
264258
let iter = super::type_serializers::generator::SerializationIterator::new(
265259
value.downcast()?,
@@ -411,36 +405,6 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
411405
}};
412406
}
413407

414-
macro_rules! serialize_with_serializer {
415-
($py_serializable:expr, $is_model:expr) => {{
416-
let py = $py_serializable.py();
417-
if let Ok(py_serializer) = value.getattr(intern!(py, "__pydantic_serializer__")) {
418-
if let Ok(extracted_serializer) = py_serializer.extract::<SchemaSerializer>() {
419-
let extra = extracted_serializer.build_extra(
420-
py,
421-
extra.mode,
422-
extra.by_alias,
423-
extra.warnings,
424-
extra.exclude_unset,
425-
extra.exclude_defaults,
426-
extra.exclude_none,
427-
extra.round_trip,
428-
extra.rec_guard,
429-
extra.serialize_unknown,
430-
extra.fallback,
431-
);
432-
let pydantic_serializer =
433-
PydanticSerializer::new(value, &extracted_serializer.serializer, include, exclude, &extra);
434-
return pydantic_serializer.serialize(serializer);
435-
}
436-
}
437-
// Fallback to dict serialization if `__pydantic_serializer__` is not set.
438-
// This currently only affects non-pydantic dataclasses.
439-
let dict = object_to_dict(value, $is_model, extra).map_err(py_err_se_err)?;
440-
serialize_dict!(dict)
441-
}};
442-
}
443-
444408
let ser_result = match ob_type {
445409
ObType::None => serializer.serialize_none(),
446410
ObType::Int | ObType::IntSubclass => serialize!(i64),
@@ -495,8 +459,30 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
495459
let py_url: PyMultiHostUrl = value.extract().map_err(py_err_se_err)?;
496460
serializer.serialize_str(&py_url.__str__())
497461
}
498-
ObType::Dataclass => serialize_with_serializer!(value, false),
499-
ObType::PydanticSerializable => serialize_with_serializer!(value, true),
462+
ObType::PydanticSerializable => {
463+
let py = value.py();
464+
let py_serializer = value
465+
.getattr(intern!(py, "__pydantic_serializer__"))
466+
.map_err(py_err_se_err)?;
467+
let extracted_serializer: SchemaSerializer = py_serializer.extract().map_err(py_err_se_err)?;
468+
let extra = extracted_serializer.build_extra(
469+
py,
470+
extra.mode,
471+
extra.by_alias,
472+
extra.warnings,
473+
extra.exclude_unset,
474+
extra.exclude_defaults,
475+
extra.exclude_none,
476+
extra.round_trip,
477+
extra.rec_guard,
478+
extra.serialize_unknown,
479+
extra.fallback,
480+
);
481+
let pydantic_serializer =
482+
PydanticSerializer::new(value, &extracted_serializer.serializer, include, exclude, &extra);
483+
pydantic_serializer.serialize(serializer)
484+
}
485+
ObType::Dataclass => serialize_dict!(dataclass_to_dict(value).map_err(py_err_se_err)?),
500486
ObType::Enum => {
501487
let v = value.getattr(intern!(value.py(), "value")).map_err(py_err_se_err)?;
502488
infer_serialize(v, serializer, include, exclude, extra)

src/serializers/shared.rs

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ use std::borrow::Cow;
22
use std::fmt::Debug;
33

44
use pyo3::exceptions::PyTypeError;
5+
use pyo3::once_cell::GILOnceCell;
56
use pyo3::prelude::*;
6-
use pyo3::types::{PyDict, PySet};
7+
use pyo3::types::{PyDict, PyString};
78
use pyo3::{intern, PyTraverseError, PyVisit};
89

910
use enum_dispatch::enum_dispatch;
@@ -95,7 +96,6 @@ combined_serializer! {
9596
super::type_serializers::other::CallableBuilder;
9697
super::type_serializers::definitions::DefinitionsSerializerBuilder;
9798
super::type_serializers::dataclass::DataclassArgsBuilder;
98-
super::type_serializers::dataclass::DataclassBuilder;
9999
super::type_serializers::function::FunctionBeforeSerializerBuilder;
100100
super::type_serializers::function::FunctionAfterSerializerBuilder;
101101
super::type_serializers::function::FunctionPlainSerializerBuilder;
@@ -123,6 +123,7 @@ combined_serializer! {
123123
Generator: super::type_serializers::generator::GeneratorSerializer;
124124
Dict: super::type_serializers::dict::DictSerializer;
125125
Model: super::type_serializers::model::ModelSerializer;
126+
Dataclass: super::type_serializers::dataclass::DataclassSerializer;
126127
Url: super::type_serializers::url::UrlSerializer;
127128
MultiHostUrl: super::type_serializers::url::MultiHostUrlSerializer;
128129
Any: super::type_serializers::any::AnySerializer;
@@ -327,21 +328,29 @@ pub(crate) fn to_json_bytes(
327328
Ok(bytes)
328329
}
329330

330-
pub(super) fn object_to_dict<'py>(value: &'py PyAny, is_model: bool, extra: &Extra) -> PyResult<&'py PyDict> {
331-
let py = value.py();
332-
let attr = value.getattr(intern!(py, "__dict__"))?;
333-
let attrs: &PyDict = attr.downcast()?;
334-
if is_model && extra.exclude_unset {
335-
let fields_set: &PySet = value.getattr(intern!(py, "__pydantic_fields_set__"))?.downcast()?;
336-
337-
let new_attrs = attrs.copy()?;
338-
for key in new_attrs.keys() {
339-
if !fields_set.contains(key)? {
340-
new_attrs.del_item(key)?;
341-
}
331+
static DC_FIELD_MARKER: GILOnceCell<PyObject> = GILOnceCell::new();
332+
333+
/// needed to match the logic from dataclasses.fields `tuple(f for f in fields.values() if f._field_type is _FIELD)`
334+
pub(super) fn get_field_marker(py: Python<'_>) -> PyResult<&PyAny> {
335+
let field_type_marker_obj = DC_FIELD_MARKER.get_or_try_init(py, || {
336+
let field_ = py.import("dataclasses")?.getattr("_FIELD")?;
337+
Ok::<PyObject, PyErr>(field_.into_py(py))
338+
})?;
339+
Ok(field_type_marker_obj.as_ref(py))
340+
}
341+
342+
pub(super) fn dataclass_to_dict(dc: &PyAny) -> PyResult<&PyDict> {
343+
let py = dc.py();
344+
let dc_fields: &PyDict = dc.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?;
345+
let dict = PyDict::new(py);
346+
347+
let field_type_marker = get_field_marker(py)?;
348+
for (field_name, field) in dc_fields.iter() {
349+
let field_type = field.getattr(intern!(py, "_field_type"))?;
350+
if field_type.is(field_type_marker) {
351+
let field_name: &PyString = field_name.downcast()?;
352+
dict.set_item(field_name, dc.getattr(field_name)?)?;
342353
}
343-
Ok(new_attrs)
344-
} else {
345-
Ok(attrs)
346354
}
355+
Ok(dict)
347356
}

0 commit comments

Comments
 (0)