Skip to content

Add slots to dataclass schema #617

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
May 25, 2023
Merged
6 changes: 5 additions & 1 deletion pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3101,6 +3101,7 @@ class DataclassSchema(TypedDict, total=False):
ref: str
metadata: Any
serialization: SerSchema
slots: List[str]


def dataclass_schema(
Expand All @@ -3115,13 +3116,14 @@ def dataclass_schema(
metadata: Any = None,
serialization: SerSchema | None = None,
frozen: bool | None = None,
slots: List[str] | None = None,
) -> DataclassSchema:
"""
Returns a schema for a dataclass. As with `ModelSchema`, this schema can only be used as a field within
another schema, not as the root type.

Args:
cls: The dataclass type, used to to perform subclass checks
cls: The dataclass type, used to perform subclass checks
schema: The schema to use for the dataclass fields
cls_name: The name to use in error locs, etc; this is useful for generics (default: `cls.__name__`)
post_init: Whether to call `__post_init__` after validation
Expand All @@ -3132,6 +3134,7 @@ def dataclass_schema(
metadata: Any other information you want to include with the schema, not used by pydantic-core
serialization: Custom serialization schema
frozen: Whether the dataclass is frozen
slots: The slots to use for the dataclass, set only if `slots=True` on the dataclass
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
slots: The slots to use for the dataclass, set only if `slots=True` on the dataclass
slots: The slots to use for the dataclass, set only if `slots=True` on the dataclass or one of its bases

"""
return dict_not_none(
type='dataclass',
Expand All @@ -3145,6 +3148,7 @@ def dataclass_schema(
metadata=metadata,
serialization=serialization,
frozen=frozen,
slots=slots,
)


Expand Down
9 changes: 2 additions & 7 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::fmt;

use pyo3::types::{PyDict, PyString, PyType};
use pyo3::types::{PyDict, PyType};
use pyo3::{intern, prelude::*};

use crate::errors::{InputValue, LocItem, ValResult};
Expand Down Expand Up @@ -40,15 +40,10 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {

fn is_none(&self) -> bool;

#[cfg_attr(has_no_coverage, no_coverage)]
fn input_get_attr(&self, _name: &PyString) -> Option<PyResult<&PyAny>> {
fn input_is_instance(&self, _class: &PyType) -> Option<&PyAny> {
None
}

fn is_exact_instance(&self, _class: &PyType) -> bool {
false
}

fn is_python(&self) -> bool {
false
}
Expand Down
12 changes: 6 additions & 6 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,12 @@ impl<'a> Input<'a> for PyAny {
self.is_none()
}

fn input_get_attr(&self, name: &PyString) -> Option<PyResult<&PyAny>> {
Some(self.getattr(name))
}

fn is_exact_instance(&self, class: &PyType) -> bool {
self.get_type().is(class)
fn input_is_instance(&self, class: &PyType) -> Option<&PyAny> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a big deal, but given this returns an Option<&PyAny> rather than a bool the name seems a bit weird to me (I would expect it to return a bool since the name sounds like an assertion). I'd suggest downcast or similar, but also fine keeping as is if you prefer.

if self.is_instance(class).unwrap_or(false) {
Some(self)
} else {
None
}
}

fn is_python(&self) -> bool {
Expand Down
106 changes: 46 additions & 60 deletions src/serializers/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use super::errors::{py_err_se_err, PydanticSerializationError};
use super::extra::{Extra, SerMode};
use super::filter::AnyFilter;
use super::ob_type::ObType;
use super::shared::object_to_dict;
use super::shared::dataclass_to_dict;

pub(crate) fn infer_to_python(
value: &PyAny,
Expand Down Expand Up @@ -97,29 +97,23 @@ pub(crate) fn infer_to_python_known(
Ok::<PyObject, PyErr>(new_dict.into_py(py))
};

let serialize_with_serializer = |value: &PyAny, is_model: bool| {
if let Ok(py_serializer) = value.getattr(intern!(py, "__pydantic_serializer__")) {
if let Ok(serializer) = py_serializer.extract::<SchemaSerializer>() {
let extra = serializer.build_extra(
py,
extra.mode,
extra.by_alias,
extra.warnings,
extra.exclude_unset,
extra.exclude_defaults,
extra.exclude_none,
extra.round_trip,
extra.rec_guard,
extra.serialize_unknown,
extra.fallback,
);
return serializer.serializer.to_python(value, include, exclude, &extra);
}
}
// Fallback to dict serialization if `__pydantic_serializer__` is not set.
// This currently only affects non-pydantic dataclasses.
let dict = object_to_dict(value, is_model, extra)?;
serialize_dict(dict)
let serialize_with_serializer = || {
let py_serializer = value.getattr(intern!(py, "__pydantic_serializer__"))?;
let serializer: SchemaSerializer = py_serializer.extract()?;
let extra = serializer.build_extra(
py,
extra.mode,
extra.by_alias,
extra.warnings,
extra.exclude_unset,
extra.exclude_defaults,
extra.exclude_none,
extra.round_trip,
extra.rec_guard,
extra.serialize_unknown,
extra.fallback,
);
serializer.serializer.to_python(value, include, exclude, &extra)
};

let value = match extra.mode {
Expand Down Expand Up @@ -191,8 +185,8 @@ pub(crate) fn infer_to_python_known(
let py_url: PyMultiHostUrl = value.extract()?;
py_url.__str__().into_py(py)
}
ObType::PydanticSerializable => serialize_with_serializer(value, true)?,
ObType::Dataclass => serialize_with_serializer(value, false)?,
ObType::PydanticSerializable => serialize_with_serializer()?,
ObType::Dataclass => serialize_dict(dataclass_to_dict(value)?)?,
ObType::Enum => {
let v = value.getattr(intern!(py, "value"))?;
infer_to_python(v, include, exclude, extra)?.into_py(py)
Expand Down Expand Up @@ -257,8 +251,8 @@ pub(crate) fn infer_to_python_known(
}
new_dict.into_py(py)
}
ObType::PydanticSerializable => serialize_with_serializer(value, true)?,
ObType::Dataclass => serialize_with_serializer(value, false)?,
ObType::PydanticSerializable => serialize_with_serializer()?,
ObType::Dataclass => serialize_dict(dataclass_to_dict(value)?)?,
ObType::Generator => {
let iter = super::type_serializers::generator::SerializationIterator::new(
value.downcast()?,
Expand Down Expand Up @@ -406,36 +400,6 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
}};
}

macro_rules! serialize_with_serializer {
($py_serializable:expr, $is_model:expr) => {{
let py = $py_serializable.py();
if let Ok(py_serializer) = value.getattr(intern!(py, "__pydantic_serializer__")) {
if let Ok(extracted_serializer) = py_serializer.extract::<SchemaSerializer>() {
let extra = extracted_serializer.build_extra(
py,
extra.mode,
extra.by_alias,
extra.warnings,
extra.exclude_unset,
extra.exclude_defaults,
extra.exclude_none,
extra.round_trip,
extra.rec_guard,
extra.serialize_unknown,
extra.fallback,
);
let pydantic_serializer =
PydanticSerializer::new(value, &extracted_serializer.serializer, include, exclude, &extra);
return pydantic_serializer.serialize(serializer);
}
}
// Fallback to dict serialization if `__pydantic_serializer__` is not set.
// This currently only affects non-pydantic dataclasses.
let dict = object_to_dict(value, $is_model, extra).map_err(py_err_se_err)?;
serialize_dict!(dict)
}};
}

let ser_result = match ob_type {
ObType::None => serializer.serialize_none(),
ObType::Int | ObType::IntSubclass => serialize!(i64),
Expand Down Expand Up @@ -490,8 +454,30 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
let py_url: PyMultiHostUrl = value.extract().map_err(py_err_se_err)?;
serializer.serialize_str(&py_url.__str__())
}
ObType::Dataclass => serialize_with_serializer!(value, false),
ObType::PydanticSerializable => serialize_with_serializer!(value, true),
ObType::PydanticSerializable => {
let py = value.py();
let py_serializer = value
.getattr(intern!(py, "__pydantic_serializer__"))
.map_err(py_err_se_err)?;
let extracted_serializer: SchemaSerializer = py_serializer.extract().map_err(py_err_se_err)?;
let extra = extracted_serializer.build_extra(
py,
extra.mode,
extra.by_alias,
extra.warnings,
extra.exclude_unset,
extra.exclude_defaults,
extra.exclude_none,
extra.round_trip,
extra.rec_guard,
extra.serialize_unknown,
extra.fallback,
);
let pydantic_serializer =
PydanticSerializer::new(value, &extracted_serializer.serializer, include, exclude, &extra);
pydantic_serializer.serialize(serializer)
}
ObType::Dataclass => serialize_dict!(dataclass_to_dict(value).map_err(py_err_se_err)?),
ObType::Enum => {
let v = value.getattr(intern!(value.py(), "value")).map_err(py_err_se_err)?;
infer_serialize(v, serializer, include, exclude, extra)
Expand Down
1 change: 1 addition & 0 deletions src/serializers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use config::SerializationConfig;
pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue};
use extra::{CollectWarnings, SerRecursionGuard};
pub(crate) use extra::{Extra, SerMode, SerializationState};
pub(crate) use shared::dataclass_to_dict;
pub use shared::CombinedSerializer;
use shared::{to_json_bytes, BuildSerializer, TypeSerializer};

Expand Down
43 changes: 26 additions & 17 deletions src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use std::borrow::Cow;
use std::fmt::Debug;

use pyo3::exceptions::PyTypeError;
use pyo3::once_cell::GILOnceCell;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PySet};
use pyo3::types::{PyDict, PyString};
use pyo3::{intern, PyTraverseError, PyVisit};

use enum_dispatch::enum_dispatch;
Expand Down Expand Up @@ -95,7 +96,6 @@ combined_serializer! {
super::type_serializers::other::CallableBuilder;
super::type_serializers::definitions::DefinitionsSerializerBuilder;
super::type_serializers::dataclass::DataclassArgsBuilder;
super::type_serializers::dataclass::DataclassBuilder;
super::type_serializers::function::FunctionBeforeSerializerBuilder;
super::type_serializers::function::FunctionAfterSerializerBuilder;
super::type_serializers::function::FunctionPlainSerializerBuilder;
Expand Down Expand Up @@ -123,6 +123,7 @@ combined_serializer! {
Generator: super::type_serializers::generator::GeneratorSerializer;
Dict: super::type_serializers::dict::DictSerializer;
Model: super::type_serializers::model::ModelSerializer;
Dataclass: super::type_serializers::dataclass::DataclassSerializer;
Url: super::type_serializers::url::UrlSerializer;
MultiHostUrl: super::type_serializers::url::MultiHostUrlSerializer;
Any: super::type_serializers::any::AnySerializer;
Expand Down Expand Up @@ -327,21 +328,29 @@ pub(crate) fn to_json_bytes(
Ok(bytes)
}

pub(super) fn object_to_dict<'py>(value: &'py PyAny, is_model: bool, extra: &Extra) -> PyResult<&'py PyDict> {
let py = value.py();
let attr = value.getattr(intern!(py, "__dict__"))?;
let attrs: &PyDict = attr.downcast()?;
if is_model && extra.exclude_unset {
let fields_set: &PySet = value.getattr(intern!(py, "__pydantic_fields_set__"))?.downcast()?;

let new_attrs = attrs.copy()?;
for key in new_attrs.keys() {
if !fields_set.contains(key)? {
new_attrs.del_item(key)?;
}
static DC_FIELD_MARKER: GILOnceCell<PyObject> = GILOnceCell::new();

/// needed to match the logic from dataclasses.fields `tuple(f for f in fields.values() if f._field_type is _FIELD)`
pub(super) fn get_field_marker(py: Python<'_>) -> PyResult<&PyAny> {
let field_type_marker_obj = DC_FIELD_MARKER.get_or_try_init(py, || {
let field_ = py.import("dataclasses")?.getattr("_FIELD")?;
Ok::<PyObject, PyErr>(field_.into_py(py))
})?;
Ok(field_type_marker_obj.as_ref(py))
}

pub(crate) fn dataclass_to_dict(dc: &PyAny) -> PyResult<&PyDict> {
let py = dc.py();
let dc_fields: &PyDict = dc.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?;
let dict = PyDict::new(py);

let field_type_marker = get_field_marker(py)?;
for (field_name, field) in dc_fields.iter() {
let field_type = field.getattr(intern!(py, "_field_type"))?;
if field_type.is(field_type_marker) {
let field_name: &PyString = field_name.downcast()?;
dict.set_item(field_name, dc.getattr(field_name)?)?;
}
Ok(new_attrs)
} else {
Ok(attrs)
}
Ok(dict)
}
Loading