Skip to content

Commit 50d3e24

Browse files
committed
fix dataclass support with slots, cleanup input
1 parent 7daf210 commit 50d3e24

File tree

9 files changed

+97
-72
lines changed

9 files changed

+97
-72
lines changed

pydantic_core/core_schema.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3101,7 +3101,7 @@ class DataclassSchema(TypedDict, total=False):
31013101
ref: str
31023102
metadata: Any
31033103
serialization: SerSchema
3104-
slots: bool
3104+
slots: List[str]
31053105

31063106

31073107
def dataclass_schema(
@@ -3116,7 +3116,7 @@ def dataclass_schema(
31163116
metadata: Any = None,
31173117
serialization: SerSchema | None = None,
31183118
frozen: bool | None = None,
3119-
slots: bool = False,
3119+
slots: List[str] | None = None,
31203120
) -> DataclassSchema:
31213121
"""
31223122
Returns a schema for a dataclass. As with `ModelSchema`, this schema can only be used as a field within
@@ -3134,7 +3134,7 @@ def dataclass_schema(
31343134
metadata: Any other information you want to include with the schema, not used by pydantic-core
31353135
serialization: Custom serialization schema
31363136
frozen: Whether the dataclass is frozen
3137-
slots: Whether the slots is enabled on dataclass
3137+
slots: The slots to use for the dataclass, set only if `slots=True` on the dataclass
31383138
"""
31393139
return dict_not_none(
31403140
type='dataclass',

src/input/input_abstract.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
4545
None
4646
}
4747

48-
fn is_exact_instance(&self, _class: &PyType) -> bool {
49-
false
48+
fn input_is_instance(&self, _class: &PyType) -> Option<&PyAny> {
49+
None
5050
}
5151

5252
fn is_python(&self) -> bool {

src/input/input_python.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,12 @@ impl<'a> Input<'a> for PyAny {
113113
Some(self.getattr(name))
114114
}
115115

116-
fn is_exact_instance(&self, class: &PyType) -> bool {
117-
self.get_type().is(class)
116+
fn input_is_instance(&self, class: &PyType) -> Option<&PyAny> {
117+
if self.is_instance(class).unwrap_or(false) {
118+
Some(self)
119+
} else {
120+
None
121+
}
118122
}
119123

120124
fn is_python(&self) -> bool {

src/serializers/shared.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::fmt::Debug;
33

44
use pyo3::exceptions::PyTypeError;
55
use pyo3::prelude::*;
6-
use pyo3::types::{PyDict, PySet};
6+
use pyo3::types::{PyDict, PySet, PyString, PyTuple};
77
use pyo3::{intern, PyTraverseError, PyVisit};
88

99
use enum_dispatch::enum_dispatch;
@@ -329,8 +329,18 @@ pub(crate) fn to_json_bytes(
329329

330330
pub(super) fn object_to_dict<'py>(value: &'py PyAny, is_model: bool, extra: &Extra) -> PyResult<&'py PyDict> {
331331
let py = value.py();
332-
let attr = value.getattr(intern!(py, "__dict__"))?;
333-
let attrs: &PyDict = attr.downcast()?;
332+
let attrs: &PyDict = match value.getattr(intern!(py, "__dict__")) {
333+
Ok(attr) => attr.downcast()?,
334+
Err(_) => {
335+
let slots: &PyTuple = value.getattr(intern!(py, "__slots__"))?.downcast()?;
336+
let dict = PyDict::new(py);
337+
for slot in slots {
338+
let slot: &PyString = slot.downcast()?;
339+
dict.set_item(slot, value.getattr(slot)?)?;
340+
}
341+
dict
342+
}
343+
};
334344
if is_model && extra.exclude_unset {
335345
let fields_set: &PySet = value.getattr(intern!(py, "__pydantic_fields_set__"))?.downcast()?;
336346

src/validators/dataclass.rs

Lines changed: 46 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ use ahash::AHashSet;
77

88
use crate::build_tools::{is_strict, py_err, schema_or_config_same, ExtraBehavior, SchemaDict};
99
use crate::errors::{ErrorType, ValError, ValLineError, ValResult};
10-
use crate::input::InputType;
1110
use crate::input::{GenericArguments, Input};
1211
use crate::lookup_key::LookupKey;
1312
use crate::recursion_guard::RecursionGuard;
@@ -415,7 +414,7 @@ pub struct DataclassValidator {
415414
revalidate: Revalidate,
416415
name: String,
417416
frozen: bool,
418-
slots: bool,
417+
slots: Option<Vec<Py<PyString>>>,
419418
}
420419

421420
impl BuildValidator for DataclassValidator {
@@ -442,6 +441,17 @@ impl BuildValidator for DataclassValidator {
442441
None
443442
};
444443

444+
let slots = match schema.get_as::<&PyList>(intern!(py, "slots"))? {
445+
Some(slots) => {
446+
let slots = slots
447+
.iter()
448+
.map(|s| Ok(s.downcast::<PyString>()?.into_py(py)))
449+
.collect::<PyResult<Vec<_>>>()?;
450+
Some(slots)
451+
}
452+
None => None,
453+
};
454+
445455
Ok(Self {
446456
strict: is_strict(schema, config)?,
447457
validator: Box::new(validator),
@@ -454,8 +464,7 @@ impl BuildValidator for DataclassValidator {
454464
)?)?,
455465
name,
456466
frozen: schema.get_as(intern!(py, "frozen"))?.unwrap_or(false),
457-
slots: matches!(class.hasattr(intern!(class.py(), "__slots__")), Ok(true))
458-
| schema.get_as(intern!(py, "slots"))?.unwrap_or(false),
467+
slots,
459468
}
460469
.into())
461470
}
@@ -477,33 +486,25 @@ impl Validator for DataclassValidator {
477486

478487
// same logic as on models
479488
let class = self.class.as_ref(py);
480-
if matches!(extra.mode, InputType::Python) && input.to_object(py).as_ref(py).is_instance(class)? {
481-
if self.revalidate.should_revalidate(input, class) {
482-
let mut validator_input = PyDict::new(py);
483-
if self.slots {
484-
let slots = input
485-
.input_get_attr(intern!(py, "__slots__"))
486-
.unwrap()?
487-
.downcast::<PyTuple>()?;
488-
for key in slots.iter() {
489-
let key: &PyString = key.downcast()?;
490-
validator_input.set_item(key, input.input_get_attr(key).unwrap()?)?;
489+
if let Some(py_input) = input.input_is_instance(class) {
490+
if self.revalidate.should_revalidate(py_input, class) {
491+
let input_dict = match py_input.getattr(intern!(py, "__dict__")) {
492+
Ok(attr) => attr,
493+
Err(_) => {
494+
// we inspect `__slots__` to get the attributes instead of using `self.slots` as a
495+
// subclass could have `slots=True`
496+
let slots: &PyTuple = py_input.getattr(intern!(py, "__slots__"))?.downcast()?;
497+
let dict = PyDict::new(py);
498+
for slot in slots {
499+
let slot: &PyString = slot.downcast()?;
500+
dict.set_item(slot, py_input.getattr(slot)?)?;
501+
}
502+
dict
491503
}
492-
} else {
493-
validator_input = input
494-
.input_get_attr(intern!(py, "__dict__"))
495-
.unwrap()?
496-
.downcast::<PyDict>()?;
497-
}
504+
};
498505
let val_output = self
499506
.validator
500-
.validate(
501-
py,
502-
validator_input.downcast::<PyAny>()?,
503-
extra,
504-
definitions,
505-
recursion_guard
506-
)?;
507+
.validate(py, input_dict, extra, definitions, recursion_guard)?;
507508
let dc = create_class(self.class.as_ref(py))?;
508509
self.set_dict_call(py, dc.as_ref(py), val_output, input)?;
509510
Ok(dc)
@@ -541,27 +542,20 @@ impl Validator for DataclassValidator {
541542
return Err(ValError::new(ErrorType::FrozenInstance, field_value));
542543
}
543544

544-
let mut dict = PyDict::new(py);
545-
let dict_py_str = intern!(py, "__dict__");
546-
547-
if self.slots {
548-
let slots = obj
549-
.input_get_attr(intern!(py, "__slots__"))
550-
.unwrap()?
551-
.downcast::<PyTuple>()?;
552-
for key in slots.iter() {
553-
let key: &PyString = key.downcast()?;
554-
dict.set_item(key, obj.input_get_attr(key).unwrap()?)?;
545+
let new_dict = if let Some(ref slots) = self.slots {
546+
let slots_dict = PyDict::new(py);
547+
for slot in slots {
548+
let slot = slot.as_ref(py);
549+
slots_dict.set_item(slot, obj.getattr(slot)?)?;
555550
}
551+
slots_dict
556552
} else {
557-
dict = obj.getattr(dict_py_str)?.downcast()?;
558-
}
553+
let dunder_dict: &PyDict = obj.getattr(intern!(py, "__dict__"))?.downcast()?;
554+
dunder_dict.copy()?
555+
};
559556

560-
let new_dict = dict.copy()?;
561557
new_dict.set_item(field_name, field_value)?;
562558

563-
// Discard the second return value, which is `init_only_args` but is always
564-
// None anyway for validate_assignment; see validate_assignment in DataclassArgsValidator
565559
let val_assignment_result = self.validator.validate_assignment(
566560
py,
567561
new_dict,
@@ -574,10 +568,11 @@ impl Validator for DataclassValidator {
574568

575569
let (dc_dict, _): (&PyDict, PyObject) = val_assignment_result.extract(py)?;
576570

577-
if self.slots {
578-
force_setattr(py, obj, field_name, field_value)?;
571+
if self.slots.is_some() {
572+
let value = dc_dict.get_item(field_name).unwrap();
573+
force_setattr(py, obj, field_name, value)?;
579574
} else {
580-
force_setattr(py, obj, dict_py_str, dc_dict)?;
575+
force_setattr(py, obj, intern!(py, "__dict__"), dc_dict)?;
581576
}
582577

583578
Ok(obj.to_object(py))
@@ -629,6 +624,7 @@ impl DataclassValidator {
629624

630625
Ok(self_instance.into_py(py))
631626
}
627+
632628
fn set_dict_call<'s, 'data>(
633629
&'s self,
634630
py: Python<'data>,
@@ -637,8 +633,9 @@ impl DataclassValidator {
637633
input: &'data impl Input<'data>,
638634
) -> ValResult<'data, ()> {
639635
let (dc_dict, post_init_kwargs): (&PyAny, &PyAny) = val_output.extract(py)?;
640-
if self.slots {
641-
for (key, value) in dc_dict.downcast::<PyDict>()?.iter() {
636+
if self.slots.is_some() {
637+
let dc_dict: &PyDict = dc_dict.downcast()?;
638+
for (key, value) in dc_dict.iter() {
642639
force_setattr(py, dc, key, value)?;
643640
}
644641
} else {

src/validators/model.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use pyo3::{ffi, intern};
88

99
use crate::build_tools::{py_err, schema_or_config_same, SchemaDict};
1010
use crate::errors::{ErrorType, ValError, ValResult};
11-
use crate::input::{py_error_on_minusone, Input, InputType};
11+
use crate::input::{py_error_on_minusone, Input};
1212
use crate::recursion_guard::RecursionGuard;
1313

1414
use super::function::convert_err;
@@ -37,11 +37,11 @@ impl Revalidate {
3737
}
3838
}
3939

40-
pub fn should_revalidate<'d>(&self, input: &impl Input<'d>, class: &PyType) -> bool {
40+
pub fn should_revalidate(&self, input: &PyAny, class: &PyType) -> bool {
4141
match self {
4242
Revalidate::Always => true,
4343
Revalidate::Never => false,
44-
Revalidate::SubclassInstances => !input.is_exact_instance(class),
44+
Revalidate::SubclassInstances => !input.get_type().is(class),
4545
}
4646
}
4747
}
@@ -125,16 +125,16 @@ impl Validator for ModelValidator {
125125
// if the input is an instance of the class, we "revalidate" it - e.g. we extract and reuse `__pydantic_fields_set__`
126126
// but use from attributes to create a new instance of the model field type
127127
let class = self.class.as_ref(py);
128-
if matches!(extra.mode, InputType::Python) && input.to_object(py).as_ref(py).is_instance(class)? {
129-
if self.revalidate.should_revalidate(input, class) {
128+
if let Some(py_input) = input.input_is_instance(class) {
129+
if self.revalidate.should_revalidate(py_input, class) {
130130
if self.root_model {
131-
let inner_input: &PyAny = input.input_get_attr(intern!(py, ROOT_FIELD)).unwrap()?;
131+
let inner_input = py_input.getattr(intern!(py, ROOT_FIELD))?;
132132
self.validate_construct(py, inner_input, None, extra, definitions, recursion_guard)
133133
} else {
134-
let fields_set = input.input_get_attr(intern!(py, DUNDER_FIELDS_SET_KEY)).unwrap()?;
134+
let fields_set = py_input.getattr(intern!(py, DUNDER_FIELDS_SET_KEY))?;
135135
// get dict here so from_attributes logic doesn't apply
136-
let dict = input.input_get_attr(intern!(py, DUNDER_DICT)).unwrap()?;
137-
let model_extra = input.input_get_attr(intern!(py, DUNDER_MODEL_EXTRA_KEY)).unwrap()?;
136+
let dict = py_input.getattr(intern!(py, DUNDER_DICT))?;
137+
let model_extra = py_input.getattr(intern!(py, DUNDER_MODEL_EXTRA_KEY))?;
138138

139139
let inner_input: &PyAny = if model_extra.is_none() {
140140
dict

tests/serializers/test_any.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,3 +473,14 @@ def __init__(self, a: str, b: bytes):
473473
assert j == b'{"a":"hello"}'
474474

475475
assert s.to_python(Foo(a='hello', b=b'more'), exclude={'a'}) == IsStrictDict()
476+
477+
478+
@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python > 3.10')
479+
def test_dataclass_slots(any_serializer):
480+
@dataclasses.dataclass(slots=True)
481+
class Foo:
482+
a: int
483+
b: str
484+
485+
foo = Foo(1, 'a')
486+
assert any_serializer.to_python(foo) == IsStrictDict(a=1, b='a')

tests/test_schema_functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,13 +250,13 @@ def args(*args, **kwargs):
250250
core_schema.dataclass_schema,
251251
# MyModel should be a dataclass, but I'm being lazy here
252252
args(MyModel, {'type': 'int'}),
253-
{'type': 'dataclass', 'schema': {'type': 'int'}, 'cls': MyModel, 'slots': False},
253+
{'type': 'dataclass', 'schema': {'type': 'int'}, 'cls': MyModel},
254254
),
255255
(
256256
core_schema.dataclass_schema,
257257
# MyModel should be a dataclass, but I'm being lazy here
258-
args(MyModel, {'type': 'int'}, slots=True),
259-
{'type': 'dataclass', 'schema': {'type': 'int'}, 'cls': MyModel, 'slots': True},
258+
args(MyModel, {'type': 'int'}, slots=['a']),
259+
{'type': 'dataclass', 'schema': {'type': 'int'}, 'cls': MyModel, 'slots': ['a']},
260260
),
261261
]
262262

tests/validators/test_dataclasses.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1204,6 +1204,7 @@ class Model:
12041204
core_schema.dataclass_args_schema(
12051205
'Model', [core_schema.dataclass_field(name='x', schema=core_schema.int_schema())]
12061206
),
1207+
slots=['x'],
12071208
)
12081209

12091210
val = SchemaValidator(schema)
@@ -1248,6 +1249,7 @@ def validate_b(cls, v: bytes, info: core_schema.FieldValidationInfo) -> bytes:
12481249
),
12491250
],
12501251
),
1252+
slots=['a', 'b'],
12511253
)
12521254

12531255
v = SchemaValidator(schema)
@@ -1281,6 +1283,7 @@ def validate_b(cls, v: str, info: core_schema.FieldValidationInfo) -> str:
12811283
),
12821284
],
12831285
),
1286+
slots=['a', 'b'],
12841287
)
12851288

12861289
v = SchemaValidator(schema)
@@ -1366,7 +1369,7 @@ def test_slots_dataclass_subclass(revalidate_instances, input_value, expected):
13661369
extra_behavior='forbid',
13671370
),
13681371
revalidate_instances=revalidate_instances,
1369-
slots=True,
1372+
slots=['a', 'b'],
13701373
)
13711374
v = SchemaValidator(schema)
13721375

0 commit comments

Comments
 (0)