Skip to content

Commit 9eae5de

Browse files
committed
properly match dataclasses.fields logic
1 parent 00a9aaf commit 9eae5de

File tree

2 files changed

+51
-6
lines changed

2 files changed

+51
-6
lines changed

src/serializers/shared.rs

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::borrow::Cow;
22
use std::fmt::Debug;
33

44
use pyo3::exceptions::PyTypeError;
5+
use pyo3::once_cell::GILOnceCell;
56
use pyo3::prelude::*;
67
use pyo3::types::{PyDict, PySet, PyString};
78
use pyo3::{intern, PyTraverseError, PyVisit};
@@ -349,13 +350,26 @@ pub(super) fn object_to_dict<'py>(value: &'py PyAny, is_model: bool, extra: &Ext
349350
}
350351
}
351352

352-
pub(crate) fn slots_dc_dict(value: &PyAny) -> PyResult<&PyDict> {
353-
let py = value.py();
354-
let dc_fields: &PyDict = value.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?;
353+
static DC_FIELD_MARKER: GILOnceCell<PyObject> = GILOnceCell::new();
354+
355+
pub(crate) fn slots_dc_dict(dc: &PyAny) -> PyResult<&PyDict> {
356+
let py = dc.py();
357+
let dc_fields: &PyDict = dc.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?;
355358
let dict = PyDict::new(py);
356-
for field in dc_fields.keys() {
357-
let field: &PyString = field.downcast()?;
358-
dict.set_item(field, value.getattr(field)?)?;
359+
360+
// need to match the logic from dataclasses.fields `tuple(f for f in fields.values() if f._field_type is _FIELD)`
361+
let field_type_marker_obj = DC_FIELD_MARKER.get_or_try_init(py, || {
362+
let field_ = py.import("dataclasses")?.getattr("_FIELD")?;
363+
Ok::<PyObject, PyErr>(field_.into_py(py))
364+
})?;
365+
let field_type_marker = field_type_marker_obj.as_ref(py);
366+
367+
for (field_name, field) in dc_fields.iter() {
368+
let field_type = field.getattr(intern!(py, "_field_type"))?;
369+
if field_type.is(field_type_marker) {
370+
let field_name: &PyString = field_name.downcast()?;
371+
dict.set_item(field_name, dc.getattr(field_name)?)?;
372+
}
359373
}
360374
Ok(dict)
361375
}

tests/serializers/test_any.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,24 @@ def __init__(self, a: str, b: bytes):
475475
assert s.to_python(Foo(a='hello', b=b'more'), exclude={'a'}) == IsStrictDict()
476476

477477

478+
def test_dataclass_classvar(any_serializer):
479+
@dataclasses.dataclass(slots=True)
480+
class Foo:
481+
a: int
482+
b: str
483+
c: ClassVar[int] = 1
484+
485+
foo = Foo(1, 'a')
486+
assert any_serializer.to_python(foo) == IsStrictDict(a=1, b='a')
487+
488+
@dataclasses.dataclass(slots=True)
489+
class Foo2(Foo):
490+
pass
491+
492+
foo2 = Foo2(2, 'b')
493+
assert any_serializer.to_python(foo2) == IsStrictDict(a=2, b='b')
494+
495+
478496
@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python > 3.10')
479497
def test_dataclass_slots(any_serializer):
480498
@dataclasses.dataclass(slots=True)
@@ -491,3 +509,16 @@ class Foo2(Foo):
491509

492510
foo2 = Foo2(2, 'b')
493511
assert any_serializer.to_python(foo2) == IsStrictDict(a=2, b='b')
512+
513+
514+
@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python > 3.10')
515+
def test_dataclass_slots_init_vars(any_serializer):
516+
@dataclasses.dataclass(slots=True)
517+
class Foo:
518+
a: int
519+
b: str
520+
c: dataclasses.InitVar[int]
521+
d: ClassVar[int] = 42
522+
523+
foo = Foo(1, 'a', 42)
524+
assert any_serializer.to_python(foo) == IsStrictDict(a=1, b='a')

0 commit comments

Comments
 (0)