Skip to content

Commit 559e0c6

Browse files
committed
add dataclass.fields to schema
1 parent a0826b1 commit 559e0c6

File tree

11 files changed

+109
-65
lines changed

11 files changed

+109
-65
lines changed

pydantic_core/core_schema.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3093,6 +3093,7 @@ class DataclassSchema(TypedDict, total=False):
30933093
type: Required[Literal['dataclass']]
30943094
cls: Required[Type[Any]]
30953095
schema: Required[CoreSchema]
3096+
fields: Required[List[str]]
30963097
cls_name: str
30973098
post_init: bool # default: False
30983099
revalidate_instances: Literal['always', 'never', 'subclass-instances'] # default: 'never'
@@ -3101,12 +3102,13 @@ class DataclassSchema(TypedDict, total=False):
31013102
ref: str
31023103
metadata: Any
31033104
serialization: SerSchema
3104-
slots: List[str]
3105+
slots: bool
31053106

31063107

31073108
def dataclass_schema(
31083109
cls: Type[Any],
31093110
schema: CoreSchema,
3111+
fields: List[str],
31103112
*,
31113113
cls_name: str | None = None,
31123114
post_init: bool | None = None,
@@ -3116,7 +3118,7 @@ def dataclass_schema(
31163118
metadata: Any = None,
31173119
serialization: SerSchema | None = None,
31183120
frozen: bool | None = None,
3119-
slots: List[str] | None = None,
3121+
slots: bool | None = None,
31203122
) -> DataclassSchema:
31213123
"""
31223124
Returns a schema for a dataclass. As with `ModelSchema`, this schema can only be used as a field within
@@ -3125,6 +3127,8 @@ def dataclass_schema(
31253127
Args:
31263128
cls: The dataclass type, used to perform subclass checks
31273129
schema: The schema to use for the dataclass fields
3130+
fields: Fields of the dataclass, this is used in serialization and in validation during re-validation
3131+
and while validating assignment
31283132
cls_name: The name to use in error locs, etc; this is useful for generics (default: `cls.__name__`)
31293133
post_init: Whether to call `__post_init__` after validation
31303134
revalidate_instances: whether instances of models and dataclasses (including subclass instances)
@@ -3134,11 +3138,13 @@ def dataclass_schema(
31343138
metadata: Any other information you want to include with the schema, not used by pydantic-core
31353139
serialization: Custom serialization schema
31363140
frozen: Whether the dataclass is frozen
3137-
slots: The slots to use for the dataclass, set only if `slots=True` on the dataclass
3141+
slots: Whether `slots=True` on the dataclass, means each field is assigned independently, rather than
3142+
simply setting `__dict__`, default false
31383143
"""
31393144
return dict_not_none(
31403145
type='dataclass',
31413146
cls=cls,
3147+
fields=fields,
31423148
cls_name=cls_name,
31433149
schema=schema,
31443150
post_init=post_init,

src/serializers/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ use config::SerializationConfig;
1111
pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue};
1212
use extra::{CollectWarnings, SerRecursionGuard};
1313
pub(crate) use extra::{Extra, SerMode, SerializationState};
14-
pub(crate) use shared::dataclass_to_dict;
1514
pub use shared::CombinedSerializer;
1615
use shared::{to_json_bytes, BuildSerializer, TypeSerializer};
1716

src/serializers/shared.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ pub(super) fn get_field_marker(py: Python<'_>) -> PyResult<&PyAny> {
339339
Ok(field_type_marker_obj.as_ref(py))
340340
}
341341

342-
pub(crate) fn dataclass_to_dict(dc: &PyAny) -> PyResult<&PyDict> {
342+
pub(super) fn dataclass_to_dict(dc: &PyAny) -> PyResult<&PyDict> {
343343
let py = dc.py();
344344
let dc_fields: &PyDict = dc.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?;
345345
let dict = PyDict::new(py);

src/serializers/type_serializers/dataclass.rs

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ use crate::build_tools::{py_error_type, ExtraBehavior, SchemaDict};
99
use crate::definitions::DefinitionsBuilder;
1010

1111
use super::{
12-
get_field_marker, infer_json_key, infer_json_key_known, infer_serialize, infer_to_python, py_err_se_err,
13-
BuildSerializer, CombinedSerializer, ComputedFields, Extra, FieldsMode, GeneralFieldsSerializer, ObType, SerCheck,
14-
SerField, TypeSerializer,
12+
infer_json_key, infer_json_key_known, infer_serialize, infer_to_python, py_err_se_err, BuildSerializer,
13+
CombinedSerializer, ComputedFields, Extra, FieldsMode, GeneralFieldsSerializer, ObType, SerCheck, SerField,
14+
TypeSerializer,
1515
};
1616

1717
pub struct DataclassArgsBuilder;
@@ -83,17 +83,11 @@ impl BuildSerializer for DataclassSerializer {
8383
let sub_schema: &PyDict = schema.get_as_req(intern!(py, "schema"))?;
8484
let serializer = Box::new(CombinedSerializer::build(sub_schema, config, definitions)?);
8585

86-
let dc_fields: &PyDict = class.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?;
87-
let mut fields = Vec::with_capacity(dc_fields.len());
88-
89-
let field_type_marker = get_field_marker(py)?;
90-
for (field_name, field) in dc_fields.iter() {
91-
let field_type = field.getattr(intern!(py, "_field_type"))?;
92-
if field_type.is(field_type_marker) {
93-
let field_name: &PyString = field_name.downcast()?;
94-
fields.push(field_name.into_py(py));
95-
}
96-
}
86+
let fields = schema
87+
.get_as_req::<&PyList>(intern!(py, "fields"))?
88+
.iter()
89+
.map(|s| Ok(s.downcast::<PyString>()?.into_py(py)))
90+
.collect::<PyResult<Vec<_>>>()?;
9791

9892
Ok(Self {
9993
class: class.into(),

src/serializers/type_serializers/mod.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,4 @@ pub(self) use super::infer::{
3535
infer_to_python_known,
3636
};
3737
pub(self) use super::ob_type::{IsType, ObType};
38-
pub(self) use super::shared::{
39-
get_field_marker, to_json_bytes, BuildSerializer, CombinedSerializer, PydanticSerializer, TypeSerializer,
40-
};
38+
pub(self) use super::shared::{to_json_bytes, BuildSerializer, CombinedSerializer, PydanticSerializer, TypeSerializer};

src/validators/dataclass.rs

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ use crate::errors::{ErrorType, ValError, ValLineError, ValResult};
1010
use crate::input::{GenericArguments, Input};
1111
use crate::lookup_key::LookupKey;
1212
use crate::recursion_guard::RecursionGuard;
13-
use crate::serializers::dataclass_to_dict;
1413
use crate::validators::function::convert_err;
1514

1615
use super::arguments::{json_get, json_slice, py_get, py_slice};
@@ -411,11 +410,12 @@ pub struct DataclassValidator {
411410
strict: bool,
412411
validator: Box<CombinedValidator>,
413412
class: Py<PyType>,
413+
fields: Vec<Py<PyString>>,
414414
post_init: Option<Py<PyString>>,
415415
revalidate: Revalidate,
416416
name: String,
417417
frozen: bool,
418-
slots: Option<Vec<Py<PyString>>>,
418+
slots: bool,
419419
}
420420

421421
impl BuildValidator for DataclassValidator {
@@ -442,21 +442,17 @@ impl BuildValidator for DataclassValidator {
442442
None
443443
};
444444

445-
let slots = match schema.get_as::<&PyList>(intern!(py, "slots"))? {
446-
Some(slots) => {
447-
let slots = slots
448-
.iter()
449-
.map(|s| Ok(s.downcast::<PyString>()?.into_py(py)))
450-
.collect::<PyResult<Vec<_>>>()?;
451-
Some(slots)
452-
}
453-
None => None,
454-
};
445+
let fields = schema
446+
.get_as_req::<&PyList>(intern!(py, "fields"))?
447+
.iter()
448+
.map(|s| Ok(s.downcast::<PyString>()?.into_py(py)))
449+
.collect::<PyResult<Vec<_>>>()?;
455450

456451
Ok(Self {
457452
strict: is_strict(schema, config)?,
458453
validator: Box::new(validator),
459454
class: class.into(),
455+
fields,
460456
post_init,
461457
revalidate: Revalidate::from_str(schema_or_config_same(
462458
schema,
@@ -465,7 +461,7 @@ impl BuildValidator for DataclassValidator {
465461
)?)?,
466462
name,
467463
frozen: schema.get_as(intern!(py, "frozen"))?.unwrap_or(false),
468-
slots,
464+
slots: schema.get_as(intern!(py, "slots"))?.unwrap_or(false),
469465
}
470466
.into())
471467
}
@@ -489,7 +485,7 @@ impl Validator for DataclassValidator {
489485
let class = self.class.as_ref(py);
490486
if let Some(py_input) = input.input_is_instance(class) {
491487
if self.revalidate.should_revalidate(py_input, class) {
492-
let input_dict: &PyAny = dataclass_to_dict(py_input)?;
488+
let input_dict: &PyAny = self.dataclass_to_dict(py, py_input)?;
493489
let val_output = self
494490
.validator
495491
.validate(py, input_dict, extra, definitions, recursion_guard)?;
@@ -530,17 +526,7 @@ impl Validator for DataclassValidator {
530526
return Err(ValError::new(ErrorType::FrozenInstance, field_value));
531527
}
532528

533-
let new_dict = if let Some(ref slots) = self.slots {
534-
let slots_dict = PyDict::new(py);
535-
for slot in slots {
536-
let slot = slot.as_ref(py);
537-
slots_dict.set_item(slot, obj.getattr(slot)?)?;
538-
}
539-
slots_dict
540-
} else {
541-
let dunder_dict: &PyDict = obj.getattr(intern!(py, "__dict__"))?.downcast()?;
542-
dunder_dict.copy()?
543-
};
529+
let new_dict = self.dataclass_to_dict(py, obj)?;
544530

545531
new_dict.set_item(field_name, field_value)?;
546532

@@ -556,8 +542,10 @@ impl Validator for DataclassValidator {
556542

557543
let (dc_dict, _): (&PyDict, PyObject) = val_assignment_result.extract(py)?;
558544

559-
if self.slots.is_some() {
560-
let value = dc_dict.get_item(field_name).unwrap();
545+
if self.slots {
546+
let value = dc_dict
547+
.get_item(field_name)
548+
.ok_or_else(|| PyKeyError::new_err(field_name.to_string()))?;
561549
force_setattr(py, obj, field_name, value)?;
562550
} else {
563551
force_setattr(py, obj, intern!(py, "__dict__"), dc_dict)?;
@@ -613,6 +601,16 @@ impl DataclassValidator {
613601
Ok(self_instance.into_py(py))
614602
}
615603

604+
fn dataclass_to_dict<'py>(&self, py: Python<'py>, dc: &'py PyAny) -> PyResult<&'py PyDict> {
605+
let dict = PyDict::new(py);
606+
607+
for field_name in &self.fields {
608+
let field_name = field_name.as_ref(py);
609+
dict.set_item(field_name, dc.getattr(field_name)?)?;
610+
}
611+
Ok(dict)
612+
}
613+
616614
fn set_dict_call<'s, 'data>(
617615
&'s self,
618616
py: Python<'data>,
@@ -621,7 +619,7 @@ impl DataclassValidator {
621619
input: &'data impl Input<'data>,
622620
) -> ValResult<'data, ()> {
623621
let (dc_dict, post_init_kwargs): (&PyAny, &PyAny) = val_output.extract(py)?;
624-
if self.slots.is_some() {
622+
if self.slots {
625623
let dc_dict: &PyDict = dc_dict.downcast()?;
626624
for (key, value) in dc_dict.iter() {
627625
force_setattr(py, dc, key, value)?;

tests/serializers/test_any.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@ class Foo:
439439
core_schema.dataclass_args_schema(
440440
'Foo', [core_schema.dataclass_field(name='a', schema=core_schema.str_schema())]
441441
),
442+
['a'],
442443
)
443444
Foo.__pydantic_serializer__ = SchemaSerializer(schema)
444445

@@ -467,6 +468,7 @@ class Foo:
467468
core_schema.dataclass_args_schema(
468469
'Foo', [core_schema.dataclass_field(name='a', schema=core_schema.str_schema())]
469470
),
471+
['a'],
470472
)
471473
Foo.__pydantic_validator__ = SchemaValidator(schema)
472474
Foo.__pydantic_serializer__ = SchemaSerializer(schema)

tests/serializers/test_dataclasses.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def test_dataclass():
3232
core_schema.dataclass_field(name='b', schema=core_schema.bytes_schema()),
3333
],
3434
),
35+
['a', 'b'],
3536
)
3637
s = SchemaSerializer(schema)
3738
assert s.to_python(Foo(a='hello', b=b'more')) == IsStrictDict(a='hello', b=b'more')
@@ -57,6 +58,7 @@ def test_serialization_exclude():
5758
core_schema.dataclass_field(name='b', schema=core_schema.bytes_schema(), serialization_exclude=True),
5859
],
5960
),
61+
['a', 'b'],
6062
)
6163
s = SchemaSerializer(schema)
6264
assert s.to_python(Foo(a='hello', b=b'more')) == {'a': 'hello'}
@@ -79,6 +81,7 @@ def test_serialization_alias():
7981
core_schema.dataclass_field(name='b', schema=core_schema.bytes_schema(), serialization_alias='BAR'),
8082
],
8183
),
84+
['a', 'b'],
8285
)
8386
s = SchemaSerializer(schema)
8487
assert s.to_python(Foo(a='hello', b=b'more')) == IsStrictDict(a='hello', BAR=b'more')
@@ -111,6 +114,7 @@ def c(self) -> str:
111114
],
112115
computed_fields=[core_schema.computed_field('c', core_schema.str_schema())],
113116
),
117+
['a', 'b'],
114118
)
115119
s = SchemaSerializer(schema)
116120
assert s.to_python(FooProp(a='hello', b=b'more')) == IsStrictDict(a='hello', b=b'more', c='hello more')
@@ -151,7 +155,8 @@ class SubModel(Model):
151155
core_schema.dataclass_field(name='y2', init_only=True, schema=core_schema.str_schema()),
152156
],
153157
),
154-
slots=['x'],
158+
['x', 'x2'],
159+
slots=True,
155160
)
156161
dc = SubModel(x=1, y='a', x2=2, y2='b')
157162
assert dataclasses.asdict(dc) == {'x': 1, 'x2': 2}

tests/test_schema_functions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -255,13 +255,13 @@ def args(*args, **kwargs):
255255
),
256256
(
257257
core_schema.dataclass_schema,
258-
args(MyDataclass, {'type': 'int'}),
259-
{'type': 'dataclass', 'schema': {'type': 'int'}, 'cls': MyDataclass},
258+
args(MyDataclass, {'type': 'int'}, ['foobar']),
259+
{'type': 'dataclass', 'schema': {'type': 'int'}, 'fields': ['foobar'], 'cls': MyDataclass},
260260
),
261261
(
262262
core_schema.dataclass_schema,
263-
args(MyDataclass, {'type': 'int'}, slots=['a']),
264-
{'type': 'dataclass', 'schema': {'type': 'int'}, 'cls': MyDataclass, 'slots': ['a']},
263+
args(MyDataclass, {'type': 'int'}, ['foobar'], slots=True),
264+
{'type': 'dataclass', 'schema': {'type': 'int'}, 'fields': ['foobar'], 'cls': MyDataclass, 'slots': True},
265265
),
266266
]
267267

0 commit comments

Comments
 (0)