Skip to content

Commit dacb9ff

Browse files
committed
benchmark
1 parent 27faf1a commit dacb9ff

File tree

1 file changed

+62
-54
lines changed
  • src/serializers/type_serializers

1 file changed

+62
-54
lines changed

src/serializers/type_serializers/union.rs

Lines changed: 62 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -80,39 +80,7 @@ impl TypeSerializer for UnionSerializer {
8080
exclude: Option<&Bound<'_, PyAny>>,
8181
extra: &Extra,
8282
) -> PyResult<PyObject> {
83-
// try the serializers in left to right order with error_on fallback=true
84-
let mut new_extra = extra.clone();
85-
new_extra.check = SerCheck::Strict;
86-
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();
87-
88-
for comb_serializer in &self.choices {
89-
match comb_serializer.to_python(value, include, exclude, &new_extra) {
90-
Ok(v) => return Ok(v),
91-
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(value.py()) {
92-
true => (),
93-
false => errors.push(err),
94-
},
95-
}
96-
}
97-
if self.retry_with_lax_check() {
98-
new_extra.check = SerCheck::Lax;
99-
for comb_serializer in &self.choices {
100-
match comb_serializer.to_python(value, include, exclude, &new_extra) {
101-
Ok(v) => return Ok(v),
102-
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(value.py()) {
103-
true => (),
104-
false => errors.push(err),
105-
},
106-
}
107-
}
108-
}
109-
110-
for err in &errors {
111-
extra.warnings.custom_warning(err.to_string());
112-
}
113-
114-
extra.warnings.on_fallback_py(self.get_name(), value, extra)?;
115-
infer_to_python(value, include, exclude, extra)
83+
to_python(value, include, exclude, extra, &self.choices, self.get_name())
11684
}
11785

11886
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
@@ -202,10 +170,55 @@ impl TypeSerializer for UnionSerializer {
202170
}
203171
}
204172

205-
#[derive(Debug, Clone)]
173+
fn to_python(
174+
value: &Bound<'_, PyAny>,
175+
include: Option<&Bound<'_, PyAny>>,
176+
exclude: Option<&Bound<'_, PyAny>>,
177+
extra: &Extra,
178+
choices: &[CombinedSerializer],
179+
name: &str,
180+
) -> PyResult<PyObject> {
181+
// try the serializers in left to right order with error_on fallback=true
182+
let mut new_extra = extra.clone();
183+
new_extra.check = SerCheck::Strict;
184+
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();
185+
186+
for comb_serializer in choices.clone() {
187+
match comb_serializer.to_python(value, include, exclude, &new_extra) {
188+
Ok(v) => return Ok(v),
189+
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(value.py()) {
190+
true => (),
191+
false => errors.push(err),
192+
},
193+
}
194+
}
195+
196+
let retry_with_lax_check = choices.clone().into_iter().any(CombinedSerializer::retry_with_lax_check);
197+
if retry_with_lax_check {
198+
new_extra.check = SerCheck::Lax;
199+
for comb_serializer in choices {
200+
match comb_serializer.to_python(value, include, exclude, &new_extra) {
201+
Ok(v) => return Ok(v),
202+
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(value.py()) {
203+
true => (),
204+
false => errors.push(err),
205+
},
206+
}
207+
}
208+
}
209+
210+
for err in &errors {
211+
extra.warnings.custom_warning(err.to_string());
212+
}
213+
214+
extra.warnings.on_fallback_py(name, value, extra)?;
215+
infer_to_python(value, include, exclude, extra)
216+
}
217+
218+
#[derive(Debug)]
206219
pub struct TaggedUnionSerializer {
207220
discriminator: Discriminator,
208-
lookup: HashMap<String, CombinedSerializer>,
221+
lookup: HashMap<String, usize>,
209222
choices: Vec<CombinedSerializer>,
210223
name: String,
211224
}
@@ -221,14 +234,15 @@ impl BuildSerializer for TaggedUnionSerializer {
221234
let py = schema.py();
222235
let discriminator = Discriminator::new(py, &schema.get_as_req(intern!(py, "discriminator"))?)?;
223236

237+
// TODO: guarantee at least 1 choice
224238
let choices_map: Bound<PyDict> = schema.get_as_req(intern!(py, "choices"))?;
225-
let mut lookup: HashMap<String, CombinedSerializer> = HashMap::with_capacity(choices_map.len());
226-
let mut choices: Vec<CombinedSerializer> = Vec::with_capacity(choices_map.len());
239+
let mut lookup = HashMap::with_capacity(choices_map.len());
240+
let mut choices = Vec::with_capacity(choices_map.len());
227241

228-
for (choice_key, choice_schema) in choices_map {
229-
let serializer = CombinedSerializer::build(choice_schema.downcast()?, config, definitions).unwrap();
230-
choices.push(serializer.clone());
231-
lookup.insert(choice_key.to_string(), serializer);
242+
for (idx, (choice_key, choice_schema)) in choices_map.into_iter().enumerate() {
243+
let serializer = CombinedSerializer::build(choice_schema.downcast()?, config, definitions)?;
244+
choices.push(serializer);
245+
lookup.insert(choice_key.to_string(), idx);
232246
}
233247

234248
let descr = choices
@@ -265,13 +279,13 @@ impl TypeSerializer for TaggedUnionSerializer {
265279
if let Some(tag) = self.get_discriminator_value(value) {
266280
let tag_str = tag.to_string();
267281
if let Some(serializer) = self.lookup.get(&tag_str) {
268-
match serializer.to_python(value, include, exclude, &new_extra) {
282+
match self.choices[*serializer].to_python(value, include, exclude, &new_extra) {
269283
Ok(v) => return Ok(v),
270284
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(py) {
271285
true => {
272286
if self.retry_with_lax_check() {
273287
new_extra.check = SerCheck::Lax;
274-
return serializer.to_python(value, include, exclude, &new_extra);
288+
return self.choices[*serializer].to_python(value, include, exclude, &new_extra);
275289
}
276290
}
277291
false => return Err(err),
@@ -280,13 +294,7 @@ impl TypeSerializer for TaggedUnionSerializer {
280294
}
281295
}
282296

283-
let basic_union_ser = UnionSerializer::from_choices(self.choices.clone());
284-
if let Ok(s) = basic_union_ser {
285-
return s.to_python(value, include, exclude, extra);
286-
}
287-
288-
extra.warnings.on_fallback_py(self.get_name(), value, extra)?;
289-
infer_to_python(value, include, exclude, extra)
297+
to_python(value, include, exclude, extra, &self.choices, self.get_name())
290298
}
291299

292300
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
@@ -296,12 +304,12 @@ impl TypeSerializer for TaggedUnionSerializer {
296304
if let Some(tag) = self.get_discriminator_value(key) {
297305
let tag_str = tag.to_string();
298306
if let Some(serializer) = self.lookup.get(&tag_str) {
299-
match serializer.json_key(key, &new_extra) {
307+
match self.choices[*serializer].json_key(key, &new_extra) {
300308
Ok(v) => return Ok(v),
301309
Err(_) => {
302310
if self.retry_with_lax_check() {
303311
new_extra.check = SerCheck::Lax;
304-
return serializer.json_key(key, &new_extra);
312+
return self.choices[*serializer].json_key(key, &new_extra);
305313
}
306314
}
307315
}
@@ -332,12 +340,12 @@ impl TypeSerializer for TaggedUnionSerializer {
332340
if let Some(tag) = self.get_discriminator_value(value) {
333341
let tag_str = tag.to_string();
334342
if let Some(selected_serializer) = self.lookup.get(&tag_str) {
335-
match selected_serializer.to_python(value, include, exclude, &new_extra) {
343+
match self.choices[*selected_serializer].to_python(value, include, exclude, &new_extra) {
336344
Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra),
337345
Err(_) => {
338346
if self.retry_with_lax_check() {
339347
new_extra.check = SerCheck::Lax;
340-
match selected_serializer.to_python(value, include, exclude, &new_extra) {
348+
match self.choices[*selected_serializer].to_python(value, include, exclude, &new_extra) {
341349
Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra),
342350
Err(err) => return Err(py_err_se_err(err)),
343351
}

0 commit comments

Comments
 (0)