@@ -8,8 +8,8 @@ use std::borrow::Cow;
8
8
use crate :: build_tools:: py_schema_err;
9
9
use crate :: common:: union:: { Discriminator , SMALL_UNION_THRESHOLD } ;
10
10
use crate :: definitions:: DefinitionsBuilder ;
11
+ use crate :: serializers:: PydanticSerializationUnexpectedValue ;
11
12
use crate :: tools:: { truncate_safe_repr, SchemaDict } ;
12
- use crate :: PydanticSerializationUnexpectedValue ;
13
13
14
14
use super :: {
15
15
infer_json_key, infer_serialize, infer_to_python, BuildSerializer , CombinedSerializer , Extra , SerCheck ,
@@ -70,22 +70,23 @@ impl UnionSerializer {
70
70
71
71
impl_py_gc_traverse ! ( UnionSerializer { choices } ) ;
72
72
73
- fn to_python (
74
- value : & Bound < ' _ , PyAny > ,
75
- include : Option < & Bound < ' _ , PyAny > > ,
76
- exclude : Option < & Bound < ' _ , PyAny > > ,
73
+ fn union_serialize < S > (
74
+ // if this returns `Ok(Some(v))`, we picked a union variant to serialize,
75
+ // Or `Ok(None)` if we couldn't find a suitable variant to serialize
76
+ // Finally, `Err(err)` if we encountered errors while trying to serialize
77
+ mut selector : impl FnMut ( & CombinedSerializer , & Extra ) -> PyResult < S > ,
77
78
extra : & Extra ,
78
79
choices : & [ CombinedSerializer ] ,
79
80
retry_with_lax_check : bool ,
80
- ) -> PyResult < PyObject > {
81
+ ) -> PyResult < Option < S > > {
81
82
// try the serializers in left to right order with error_on fallback=true
82
83
let mut new_extra = extra. clone ( ) ;
83
84
new_extra. check = SerCheck :: Strict ;
84
85
let mut errors: SmallVec < [ PyErr ; SMALL_UNION_THRESHOLD ] > = SmallVec :: new ( ) ;
85
86
86
87
for comb_serializer in choices {
87
- match comb_serializer . to_python ( value , include , exclude , & new_extra) {
88
- Ok ( v) => return Ok ( v ) ,
88
+ match selector ( comb_serializer , & new_extra) {
89
+ Ok ( v) => return Ok ( Some ( v ) ) ,
89
90
Err ( err) => errors. push ( err) ,
90
91
}
91
92
}
@@ -94,8 +95,8 @@ fn to_python(
94
95
if extra. check != SerCheck :: Strict && retry_with_lax_check {
95
96
new_extra. check = SerCheck :: Lax ;
96
97
for comb_serializer in choices {
97
- if let Ok ( v) = comb_serializer . to_python ( value , include , exclude , & new_extra) {
98
- return Ok ( v ) ;
98
+ if let Ok ( v) = selector ( comb_serializer , & new_extra) {
99
+ return Ok ( Some ( v ) ) ;
99
100
}
100
101
}
101
102
}
@@ -113,94 +114,45 @@ fn to_python(
113
114
return Err ( PydanticSerializationUnexpectedValue :: new_err ( Some ( message) ) ) ;
114
115
}
115
116
116
- infer_to_python ( value , include , exclude , extra )
117
+ Ok ( None )
117
118
}
118
119
119
- fn json_key < ' a > (
120
- key : & ' a Bound < ' _ , PyAny > ,
120
+ fn tagged_union_serialize < S > (
121
+ discriminator_value : Option < Py < PyAny > > ,
122
+ lookup : & HashMap < String , usize > ,
123
+ // if this returns `Ok(v)`, we picked a union variant to serialize, where
124
+ // `S` is intermediate state which can be passed on to the finalizer
125
+ mut selector : impl FnMut ( & CombinedSerializer , & Extra ) -> PyResult < S > ,
121
126
extra : & Extra ,
122
127
choices : & [ CombinedSerializer ] ,
123
128
retry_with_lax_check : bool ,
124
- ) -> PyResult < Cow < ' a , str > > {
129
+ ) -> PyResult < Option < S > > {
125
130
let mut new_extra = extra. clone ( ) ;
126
131
new_extra. check = SerCheck :: Strict ;
127
- let mut errors: SmallVec < [ PyErr ; SMALL_UNION_THRESHOLD ] > = SmallVec :: new ( ) ;
128
-
129
- for comb_serializer in choices {
130
- match comb_serializer. json_key ( key, & new_extra) {
131
- Ok ( v) => return Ok ( v) ,
132
- Err ( err) => errors. push ( err) ,
133
- }
134
- }
135
132
136
- // If extra.check is SerCheck::Strict, we're in a nested union
137
- if extra. check != SerCheck :: Strict && retry_with_lax_check {
138
- new_extra. check = SerCheck :: Lax ;
139
- for comb_serializer in choices {
140
- if let Ok ( v) = comb_serializer. json_key ( key, & new_extra) {
141
- return Ok ( v) ;
142
- }
143
- }
144
- }
145
-
146
- // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
147
- if extra. check == SerCheck :: None {
148
- for err in & errors {
149
- extra. warnings . custom_warning ( err. to_string ( ) ) ;
150
- }
151
- }
152
- // Otherwise, if we've encountered errors, return them to the parent union, which should take
153
- // care of the formatting for us
154
- else if !errors. is_empty ( ) {
155
- let message = errors. iter ( ) . map ( ToString :: to_string) . collect :: < Vec < _ > > ( ) . join ( "\n " ) ;
156
- return Err ( PydanticSerializationUnexpectedValue :: new_err ( Some ( message) ) ) ;
157
- }
158
- infer_json_key ( key, extra)
159
- }
160
-
161
- #[ allow( clippy:: too_many_arguments) ]
162
- fn serde_serialize < S : serde:: ser:: Serializer > (
163
- value : & Bound < ' _ , PyAny > ,
164
- serializer : S ,
165
- include : Option < & Bound < ' _ , PyAny > > ,
166
- exclude : Option < & Bound < ' _ , PyAny > > ,
167
- extra : & Extra ,
168
- choices : & [ CombinedSerializer ] ,
169
- retry_with_lax_check : bool ,
170
- ) -> Result < S :: Ok , S :: Error > {
171
- let py = value. py ( ) ;
172
- let mut new_extra = extra. clone ( ) ;
173
- new_extra. check = SerCheck :: Strict ;
174
- let mut errors: SmallVec < [ PyErr ; SMALL_UNION_THRESHOLD ] > = SmallVec :: new ( ) ;
175
-
176
- for comb_serializer in choices {
177
- match comb_serializer. to_python ( value, include, exclude, & new_extra) {
178
- Ok ( v) => return infer_serialize ( v. bind ( py) , serializer, None , None , extra) ,
179
- Err ( err) => errors. push ( err) ,
180
- }
181
- }
182
-
183
- // If extra.check is SerCheck::Strict, we're in a nested union
184
- if extra. check != SerCheck :: Strict && retry_with_lax_check {
185
- new_extra. check = SerCheck :: Lax ;
186
- for comb_serializer in choices {
187
- if let Ok ( v) = comb_serializer. to_python ( value, include, exclude, & new_extra) {
188
- return infer_serialize ( v. bind ( py) , serializer, None , None , extra) ;
133
+ if let Some ( tag) = discriminator_value {
134
+ let tag_str = tag. to_string ( ) ;
135
+ if let Some ( & serializer_index) = lookup. get ( & tag_str) {
136
+ let selected_serializer = & choices[ serializer_index] ;
137
+
138
+ match selector ( selected_serializer, & new_extra) {
139
+ Ok ( v) => return Ok ( Some ( v) ) ,
140
+ Err ( _) => {
141
+ if retry_with_lax_check {
142
+ new_extra. check = SerCheck :: Lax ;
143
+ if let Ok ( v) = selector ( selected_serializer, & new_extra) {
144
+ return Ok ( Some ( v) ) ;
145
+ }
146
+ }
147
+ }
189
148
}
190
149
}
191
150
}
192
151
193
- // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
194
- if extra. check == SerCheck :: None {
195
- for err in & errors {
196
- extra. warnings . custom_warning ( err. to_string ( ) ) ;
197
- }
198
- } else {
199
- // NOTE: if this function becomes recursive at some point, an `Err(_)` containing the errors
200
- // will have to be returned here
201
- }
202
-
203
- infer_serialize ( value, serializer, include, exclude, extra)
152
+ // if we haven't returned at this point, we should fallback to the union serializer
153
+ // which preserves the historical expectation that we do our best with serialization
154
+ // even if that means we resort to inference
155
+ union_serialize ( selector, extra, choices, retry_with_lax_check)
204
156
}
205
157
206
158
impl TypeSerializer for UnionSerializer {
@@ -211,18 +163,23 @@ impl TypeSerializer for UnionSerializer {
211
163
exclude : Option < & Bound < ' _ , PyAny > > ,
212
164
extra : & Extra ,
213
165
) -> PyResult < PyObject > {
214
- to_python (
215
- value,
216
- include,
217
- exclude,
166
+ union_serialize (
167
+ |comb_serializer, new_extra| comb_serializer. to_python ( value, include, exclude, new_extra) ,
218
168
extra,
219
169
& self . choices ,
220
170
self . retry_with_lax_check ( ) ,
221
- )
171
+ ) ?
172
+ . map_or_else ( || infer_to_python ( value, include, exclude, extra) , Ok )
222
173
}
223
174
224
175
fn json_key < ' a > ( & self , key : & ' a Bound < ' _ , PyAny > , extra : & Extra ) -> PyResult < Cow < ' a , str > > {
225
- json_key ( key, extra, & self . choices , self . retry_with_lax_check ( ) )
176
+ union_serialize (
177
+ |comb_serializer, new_extra| comb_serializer. json_key ( key, new_extra) ,
178
+ extra,
179
+ & self . choices ,
180
+ self . retry_with_lax_check ( ) ,
181
+ ) ?
182
+ . map_or_else ( || infer_json_key ( key, extra) , Ok )
226
183
}
227
184
228
185
fn serde_serialize < S : serde:: ser:: Serializer > (
@@ -233,15 +190,16 @@ impl TypeSerializer for UnionSerializer {
233
190
exclude : Option < & Bound < ' _ , PyAny > > ,
234
191
extra : & Extra ,
235
192
) -> Result < S :: Ok , S :: Error > {
236
- serde_serialize (
237
- value,
238
- serializer,
239
- include,
240
- exclude,
193
+ match union_serialize (
194
+ |comb_serializer, new_extra| comb_serializer. to_python ( value, include, exclude, new_extra) ,
241
195
extra,
242
196
& self . choices ,
243
197
self . retry_with_lax_check ( ) ,
244
- )
198
+ ) {
199
+ Ok ( Some ( v) ) => return infer_serialize ( v. bind ( value. py ( ) ) , serializer, None , None , extra) ,
200
+ Ok ( None ) => infer_serialize ( value, serializer, include, exclude, extra) ,
201
+ Err ( err) => Err ( serde:: ser:: Error :: custom ( err. to_string ( ) ) ) ,
202
+ }
245
203
}
246
204
247
205
fn get_name ( & self ) -> & str {
@@ -309,62 +267,29 @@ impl TypeSerializer for TaggedUnionSerializer {
309
267
exclude : Option < & Bound < ' _ , PyAny > > ,
310
268
extra : & Extra ,
311
269
) -> PyResult < PyObject > {
312
- let mut new_extra = extra. clone ( ) ;
313
- new_extra. check = SerCheck :: Strict ;
314
-
315
- if let Some ( tag) = self . get_discriminator_value ( value, extra) {
316
- let tag_str = tag. to_string ( ) ;
317
- if let Some ( & serializer_index) = self . lookup . get ( & tag_str) {
318
- let serializer = & self . choices [ serializer_index] ;
319
-
320
- match serializer. to_python ( value, include, exclude, & new_extra) {
321
- Ok ( v) => return Ok ( v) ,
322
- Err ( _) => {
323
- if self . retry_with_lax_check ( ) {
324
- new_extra. check = SerCheck :: Lax ;
325
- if let Ok ( v) = serializer. to_python ( value, include, exclude, & new_extra) {
326
- return Ok ( v) ;
327
- }
328
- }
329
- }
330
- }
331
- }
332
- }
333
-
334
- to_python (
335
- value,
336
- include,
337
- exclude,
270
+ tagged_union_serialize (
271
+ self . get_discriminator_value ( value, extra) ,
272
+ & self . lookup ,
273
+ |comb_serializer : & CombinedSerializer , new_extra : & Extra | {
274
+ comb_serializer. to_python ( value, include, exclude, new_extra)
275
+ } ,
338
276
extra,
339
277
& self . choices ,
340
278
self . retry_with_lax_check ( ) ,
341
- )
279
+ ) ?
280
+ . map_or_else ( || infer_to_python ( value, include, exclude, extra) , Ok )
342
281
}
343
282
344
283
fn json_key < ' a > ( & self , key : & ' a Bound < ' _ , PyAny > , extra : & Extra ) -> PyResult < Cow < ' a , str > > {
345
- let mut new_extra = extra. clone ( ) ;
346
- new_extra. check = SerCheck :: Strict ;
347
-
348
- if let Some ( tag) = self . get_discriminator_value ( key, extra) {
349
- let tag_str = tag. to_string ( ) ;
350
- if let Some ( & serializer_index) = self . lookup . get ( & tag_str) {
351
- let serializer = & self . choices [ serializer_index] ;
352
-
353
- match serializer. json_key ( key, & new_extra) {
354
- Ok ( v) => return Ok ( v) ,
355
- Err ( _) => {
356
- if self . retry_with_lax_check ( ) {
357
- new_extra. check = SerCheck :: Lax ;
358
- if let Ok ( v) = serializer. json_key ( key, & new_extra) {
359
- return Ok ( v) ;
360
- }
361
- }
362
- }
363
- }
364
- }
365
- }
366
-
367
- json_key ( key, extra, & self . choices , self . retry_with_lax_check ( ) )
284
+ tagged_union_serialize (
285
+ self . get_discriminator_value ( key, extra) ,
286
+ & self . lookup ,
287
+ |comb_serializer : & CombinedSerializer , new_extra : & Extra | comb_serializer. json_key ( key, new_extra) ,
288
+ extra,
289
+ & self . choices ,
290
+ self . retry_with_lax_check ( ) ,
291
+ ) ?
292
+ . map_or_else ( || infer_json_key ( key, extra) , Ok )
368
293
}
369
294
370
295
fn serde_serialize < S : serde:: ser:: Serializer > (
@@ -375,38 +300,20 @@ impl TypeSerializer for TaggedUnionSerializer {
375
300
exclude : Option < & Bound < ' _ , PyAny > > ,
376
301
extra : & Extra ,
377
302
) -> Result < S :: Ok , S :: Error > {
378
- let py = value. py ( ) ;
379
- let mut new_extra = extra. clone ( ) ;
380
- new_extra. check = SerCheck :: Strict ;
381
-
382
- if let Some ( tag) = self . get_discriminator_value ( value, extra) {
383
- let tag_str = tag. to_string ( ) ;
384
- if let Some ( & serializer_index) = self . lookup . get ( & tag_str) {
385
- let selected_serializer = & self . choices [ serializer_index] ;
386
-
387
- match selected_serializer. to_python ( value, include, exclude, & new_extra) {
388
- Ok ( v) => return infer_serialize ( v. bind ( py) , serializer, None , None , extra) ,
389
- Err ( _) => {
390
- if self . retry_with_lax_check ( ) {
391
- new_extra. check = SerCheck :: Lax ;
392
- if let Ok ( v) = selected_serializer. to_python ( value, include, exclude, & new_extra) {
393
- return infer_serialize ( v. bind ( py) , serializer, None , None , extra) ;
394
- }
395
- }
396
- }
397
- }
398
- }
399
- }
400
-
401
- serde_serialize (
402
- value,
403
- serializer,
404
- include,
405
- exclude,
303
+ match tagged_union_serialize (
304
+ None ,
305
+ & self . lookup ,
306
+ |comb_serializer : & CombinedSerializer , new_extra : & Extra | {
307
+ comb_serializer. to_python ( value, include, exclude, new_extra)
308
+ } ,
406
309
extra,
407
310
& self . choices ,
408
311
self . retry_with_lax_check ( ) ,
409
- )
312
+ ) {
313
+ Ok ( Some ( v) ) => return infer_serialize ( v. bind ( value. py ( ) ) , serializer, None , None , extra) ,
314
+ Ok ( None ) => infer_serialize ( value, serializer, include, exclude, extra) ,
315
+ Err ( err) => Err ( serde:: ser:: Error :: custom ( err. to_string ( ) ) ) ,
316
+ }
410
317
}
411
318
412
319
fn get_name ( & self ) -> & str {
0 commit comments