@@ -2,6 +2,7 @@ use std::borrow::Cow;
2
2
use std:: fmt:: Debug ;
3
3
4
4
use pyo3:: exceptions:: PyTypeError ;
5
+ use pyo3:: once_cell:: GILOnceCell ;
5
6
use pyo3:: prelude:: * ;
6
7
use pyo3:: types:: { PyDict , PySet , PyString } ;
7
8
use pyo3:: { intern, PyTraverseError , PyVisit } ;
@@ -349,13 +350,26 @@ pub(super) fn object_to_dict<'py>(value: &'py PyAny, is_model: bool, extra: &Ext
349
350
}
350
351
}
351
352
352
- pub ( crate ) fn slots_dc_dict ( value : & PyAny ) -> PyResult < & PyDict > {
353
- let py = value. py ( ) ;
354
- let dc_fields: & PyDict = value. getattr ( intern ! ( py, "__dataclass_fields__" ) ) ?. downcast ( ) ?;
353
+ static DC_FIELD_MARKER : GILOnceCell < PyObject > = GILOnceCell :: new ( ) ;
354
+
355
+ pub ( crate ) fn slots_dc_dict ( dc : & PyAny ) -> PyResult < & PyDict > {
356
+ let py = dc. py ( ) ;
357
+ let dc_fields: & PyDict = dc. getattr ( intern ! ( py, "__dataclass_fields__" ) ) ?. downcast ( ) ?;
355
358
let dict = PyDict :: new ( py) ;
356
- for field in dc_fields. keys ( ) {
357
- let field: & PyString = field. downcast ( ) ?;
358
- dict. set_item ( field, value. getattr ( field) ?) ?;
359
+
360
+ // need to match the logic from dataclasses.fields `tuple(f for f in fields.values() if f._field_type is _FIELD)`
361
+ let field_type_marker_obj = DC_FIELD_MARKER . get_or_try_init ( py, || {
362
+ let field_ = py. import ( "dataclasses" ) ?. getattr ( "_FIELD" ) ?;
363
+ Ok :: < PyObject , PyErr > ( field_. into_py ( py) )
364
+ } ) ?;
365
+ let field_type_marker = field_type_marker_obj. as_ref ( py) ;
366
+
367
+ for ( field_name, field) in dc_fields. iter ( ) {
368
+ let field_type = field. getattr ( intern ! ( py, "_field_type" ) ) ?;
369
+ if field_type. is ( field_type_marker) {
370
+ let field_name: & PyString = field_name. downcast ( ) ?;
371
+ dict. set_item ( field_name, dc. getattr ( field_name) ?) ?;
372
+ }
359
373
}
360
374
Ok ( dict)
361
375
}
0 commit comments