@@ -62,24 +62,6 @@ impl SerField {
62
62
}
63
63
Cow :: Borrowed ( key_str)
64
64
}
65
-
66
- pub fn to_python (
67
- & self ,
68
- output_dict : & PyDict ,
69
- value : & PyAny ,
70
- next_include : Option < & PyAny > ,
71
- next_exclude : Option < & PyAny > ,
72
- extra : & Extra ,
73
- ) -> PyResult < ( ) > {
74
- if let Some ( ref serializer) = self . serializer {
75
- if !exclude_default ( value, extra, serializer) ? {
76
- let value = serializer. to_python ( value, next_include, next_exclude, extra) ?;
77
- let output_key = self . get_key_py ( output_dict. py ( ) , extra) ;
78
- output_dict. set_item ( output_key, value) ?;
79
- }
80
- }
81
- Ok ( ( ) )
82
- }
83
65
}
84
66
85
67
fn exclude_default ( value : & PyAny , extra : & Extra , serializer : & CombinedSerializer ) -> PyResult < bool > {
@@ -93,12 +75,22 @@ fn exclude_default(value: &PyAny, extra: &Extra, serializer: &CombinedSerializer
93
75
Ok ( false )
94
76
}
95
77
78
+ #[ derive( Debug , Clone ) ]
79
+ pub ( super ) enum FieldsMode {
80
+ // typeddict with no extra items
81
+ SimpleDict ,
82
+ // a model - get `__dict__` and `__pydantic_extra__` - `GeneralFieldsSerializer` will get a tuple
83
+ ModelExtra ,
84
+ // typeddict with extra items - one dict with extra items
85
+ TypedDictAllow ,
86
+ }
87
+
96
88
/// General purpose serializer for fields - used by dataclasses, models and typed_dicts
97
89
#[ derive( Debug , Clone ) ]
98
90
pub struct GeneralFieldsSerializer {
99
91
fields : AHashMap < String , SerField > ,
100
92
computed_fields : Option < ComputedFields > ,
101
- include_extra : bool ,
93
+ mode : FieldsMode ,
102
94
// isize because we look up filter via `.hash()` which returns an isize
103
95
filter : SchemaFilter < isize > ,
104
96
required_fields : usize ,
@@ -107,26 +99,39 @@ pub struct GeneralFieldsSerializer {
107
99
impl GeneralFieldsSerializer {
108
100
pub ( super ) fn new (
109
101
fields : AHashMap < String , SerField > ,
110
- include_extra : bool ,
102
+ mode : FieldsMode ,
111
103
computed_fields : Option < ComputedFields > ,
112
104
) -> Self {
113
105
let required_fields = fields. values ( ) . filter ( |f| f. required ) . count ( ) ;
114
106
Self {
115
107
fields,
116
- include_extra ,
108
+ mode ,
117
109
filter : SchemaFilter :: default ( ) ,
118
110
computed_fields,
119
111
required_fields,
120
112
}
121
113
}
122
114
123
115
fn extract_dicts < ' a > ( & self , value : & ' a PyAny ) -> Option < ( & ' a PyDict , Option < & ' a PyDict > ) > {
124
- if let Ok ( main_dict) = value. downcast :: < PyDict > ( ) {
125
- Some ( ( main_dict, None ) )
126
- } else if let Ok ( ( main_dict, extra_dict) ) = value. extract :: < ( & PyDict , & PyDict ) > ( ) {
127
- Some ( ( main_dict, Some ( extra_dict) ) )
128
- } else {
129
- None
116
+ match self . mode {
117
+ FieldsMode :: ModelExtra => {
118
+ if let Ok ( ( main_dict, extra_dict) ) = value. extract :: < ( & PyDict , & PyAny ) > ( ) {
119
+ if let Ok ( extra_dict) = extra_dict. downcast :: < PyDict > ( ) {
120
+ Some ( ( main_dict, Some ( extra_dict) ) )
121
+ } else {
122
+ Some ( ( main_dict, None ) )
123
+ }
124
+ } else {
125
+ None
126
+ }
127
+ }
128
+ _ => {
129
+ if let Ok ( main_dict) = value. downcast :: < PyDict > ( ) {
130
+ Some ( ( main_dict, None ) )
131
+ } else {
132
+ None
133
+ }
134
+ }
130
135
}
131
136
}
132
137
}
@@ -163,32 +168,39 @@ impl TypeSerializer for GeneralFieldsSerializer {
163
168
return infer_to_python ( value, include, exclude, & td_extra) ;
164
169
} ;
165
170
166
- // NOTE! we maintain the order of the input dict assuming that's right
167
171
let output_dict = PyDict :: new ( py) ;
168
172
let mut used_req_fields: usize = 0 ;
169
173
174
+ // NOTE! we maintain the order of the input dict assuming that's right
170
175
for ( key, value) in main_dict {
176
+ let key_str = key_str ( key) ?;
177
+ let op_field = self . fields . get ( key_str) ;
171
178
if extra. exclude_none && value. is_none ( ) {
179
+ if let Some ( field) = op_field {
180
+ if field. required {
181
+ used_req_fields += 1 ;
182
+ }
183
+ }
172
184
continue ;
173
185
}
186
+ let extra = Extra {
187
+ field_name : Some ( key_str) ,
188
+ ..td_extra
189
+ } ;
174
190
if let Some ( ( next_include, next_exclude) ) = self . filter . key_filter ( key, include, exclude) ? {
175
- let extra = Extra {
176
- field_name : Some ( key. extract ( ) ?) ,
177
- ..td_extra
178
- } ;
179
- if let Ok ( key_py_str) = key. downcast :: < PyString > ( ) {
180
- let key_str = key_py_str. to_str ( ) ?;
181
- if let Some ( field) = self . fields . get ( key_str) {
182
- field. to_python ( output_dict, value, next_include, next_exclude, & extra) ?;
183
-
184
- if field. required {
185
- used_req_fields += 1 ;
191
+ if let Some ( field) = op_field {
192
+ if let Some ( ref serializer) = field. serializer {
193
+ if !exclude_default ( value, & extra, serializer) ? {
194
+ let value = serializer. to_python ( value, next_include, next_exclude, & extra) ?;
195
+ let output_key = field. get_key_py ( output_dict. py ( ) , & extra) ;
196
+ output_dict. set_item ( output_key, value) ?;
186
197
}
187
- continue ;
188
198
}
189
- }
190
- if self . include_extra {
191
- // TODO test this
199
+
200
+ if field. required {
201
+ used_req_fields += 1 ;
202
+ }
203
+ } else if matches ! ( self . mode, FieldsMode :: TypedDictAllow ) {
192
204
let value = infer_to_python ( value, next_include, next_exclude, & extra) ?;
193
205
output_dict. set_item ( key, value) ?;
194
206
} else if extra. check . enabled ( ) {
@@ -242,40 +254,37 @@ impl TypeSerializer for GeneralFieldsSerializer {
242
254
model : extra. model . map_or_else ( || Some ( value) , Some ) ,
243
255
..* extra
244
256
} ;
245
- let expected_len = match self . include_extra {
246
- true => main_dict. len ( ) + option_length ! ( self . computed_fields) ,
247
- false => self . fields . len ( ) + option_length ! ( extra_dict) + option_length ! ( self . computed_fields) ,
257
+ let expected_len = match self . mode {
258
+ FieldsMode :: TypedDictAllow => main_dict. len ( ) + option_length ! ( self . computed_fields) ,
259
+ _ => self . fields . len ( ) + option_length ! ( extra_dict) + option_length ! ( self . computed_fields) ,
248
260
} ;
249
261
// NOTE! As above, we maintain the order of the input dict assuming that's right
250
262
// we don't both with `used_fields` here because on unions, `to_python(..., mode='json')` is used
251
263
let mut map = serializer. serialize_map ( Some ( expected_len) ) ?;
252
264
253
265
for ( key, value) in main_dict {
254
- let extra = Extra {
255
- field_name : Some ( key. extract ( ) . map_err ( py_err_se_err) ?) ,
256
- ..td_extra
257
- } ;
258
266
if extra. exclude_none && value. is_none ( ) {
259
267
continue ;
260
268
}
269
+ let key_str = key_str ( key) . map_err ( py_err_se_err) ?;
270
+ let extra = Extra {
271
+ field_name : Some ( key_str) ,
272
+ ..td_extra
273
+ } ;
274
+
261
275
let filter = self . filter . key_filter ( key, include, exclude) . map_err ( py_err_se_err) ?;
262
276
if let Some ( ( next_include, next_exclude) ) = filter {
263
- if let Ok ( key_py_str) = key. downcast :: < PyString > ( ) {
264
- let key_str = key_py_str. to_str ( ) . map_err ( py_err_se_err) ?;
265
- if let Some ( field) = self . fields . get ( key_str) {
266
- if let Some ( ref serializer) = field. serializer {
267
- if !exclude_default ( value, & extra, serializer) . map_err ( py_err_se_err) ? {
268
- let s = PydanticSerializer :: new ( value, serializer, next_include, next_exclude, & extra) ;
269
- let output_key = field. get_key_json ( key_str, & extra) ;
270
- map. serialize_entry ( & output_key, & s) ?;
271
- }
272
- continue ;
277
+ if let Some ( field) = self . fields . get ( key_str) {
278
+ if let Some ( ref serializer) = field. serializer {
279
+ if !exclude_default ( value, & extra, serializer) . map_err ( py_err_se_err) ? {
280
+ let s = PydanticSerializer :: new ( value, serializer, next_include, next_exclude, & extra) ;
281
+ let output_key = field. get_key_json ( key_str, & extra) ;
282
+ map. serialize_entry ( & output_key, & s) ?;
273
283
}
274
284
}
275
- }
276
- if self . include_extra {
277
- let s = SerializeInfer :: new ( value, include, exclude, & extra) ;
285
+ } else if matches ! ( self . mode, FieldsMode :: TypedDictAllow ) {
278
286
let output_key = infer_json_key ( key, & extra) . map_err ( py_err_se_err) ?;
287
+ let s = SerializeInfer :: new ( value, next_include, next_exclude, & extra) ;
279
288
map. serialize_entry ( & output_key, & s) ?
280
289
}
281
290
}
@@ -303,3 +312,7 @@ impl TypeSerializer for GeneralFieldsSerializer {
303
312
"fields"
304
313
}
305
314
}
315
+
316
+ fn key_str ( key : & PyAny ) -> PyResult < & str > {
317
+ key. downcast :: < PyString > ( ) ?. to_str ( )
318
+ }
0 commit comments