Skip to content

Commit c325609

Browse files
committed
final cleanup, thanks @davidhewitt
1 parent 7f2b1da commit c325609

File tree

1 file changed

+50
-32
lines changed
  • src/serializers/type_serializers

1 file changed

+50
-32
lines changed

src/serializers/type_serializers/union.rs

Lines changed: 50 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ fn to_python(
8686
new_extra.check = SerCheck::Strict;
8787
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();
8888

89-
for comb_serializer in choices.clone() {
89+
for comb_serializer in choices {
9090
match comb_serializer.to_python(value, include, exclude, &new_extra) {
9191
Ok(v) => return Ok(v),
9292
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(value.py()) {
@@ -117,18 +117,18 @@ fn to_python(
117117
infer_to_python(value, include, exclude, extra)
118118
}
119119

120-
fn json_key(
121-
key: &Bound<'_, PyAny>,
120+
fn json_key<'a>(
121+
key: &'a Bound<'_, PyAny>,
122122
extra: &Extra,
123123
choices: &[CombinedSerializer],
124124
name: &str,
125125
retry_with_lax_check: bool,
126-
) -> PyResult<Cow<str>> {
126+
) -> PyResult<Cow<'a, str>> {
127127
let mut new_extra = extra.clone();
128128
new_extra.check = SerCheck::Strict;
129129
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();
130130

131-
for comb_serializer in choices.clone() {
131+
for comb_serializer in choices {
132132
match comb_serializer.json_key(key, &new_extra) {
133133
Ok(v) => return Ok(v),
134134
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(key.py()) {
@@ -159,6 +159,7 @@ fn json_key(
159159
infer_json_key(key, extra)
160160
}
161161

162+
#[allow(clippy::too_many_arguments)]
162163
fn serde_serialize<S: serde::ser::Serializer>(
163164
value: &Bound<'_, PyAny>,
164165
serializer: S,
@@ -174,7 +175,7 @@ fn serde_serialize<S: serde::ser::Serializer>(
174175
new_extra.check = SerCheck::Strict;
175176
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();
176177

177-
for comb_serializer in choices.clone() {
178+
for comb_serializer in choices {
178179
match comb_serializer.to_python(value, include, exclude, &new_extra) {
179180
Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra),
180181
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(py) {
@@ -303,7 +304,7 @@ impl BuildSerializer for TaggedUnionSerializer {
303304
}
304305
}
305306

306-
impl_py_gc_traverse!(TaggedUnionSerializer { discriminator, lookup });
307+
impl_py_gc_traverse!(TaggedUnionSerializer { discriminator, choices });
307308

308309
impl TypeSerializer for TaggedUnionSerializer {
309310
fn to_python(
@@ -318,16 +319,18 @@ impl TypeSerializer for TaggedUnionSerializer {
318319
let mut new_extra = extra.clone();
319320
new_extra.check = SerCheck::Strict;
320321

321-
if let Some(tag) = self.get_discriminator_value(value) {
322+
if let Some(tag) = self.get_discriminator_value(value, extra) {
322323
let tag_str = tag.to_string();
323-
if let Some(serializer) = self.lookup.get(&tag_str) {
324-
match self.choices[*serializer].to_python(value, include, exclude, &new_extra) {
324+
if let Some(&serializer_index) = self.lookup.get(&tag_str) {
325+
let serializer = &self.choices[serializer_index];
326+
327+
match serializer.to_python(value, include, exclude, &new_extra) {
325328
Ok(v) => return Ok(v),
326329
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(py) {
327330
true => {
328331
if self.retry_with_lax_check() {
329332
new_extra.check = SerCheck::Lax;
330-
return self.choices[*serializer].to_python(value, include, exclude, &new_extra);
333+
return serializer.to_python(value, include, exclude, &new_extra);
331334
}
332335
}
333336
false => return Err(err),
@@ -348,20 +351,26 @@ impl TypeSerializer for TaggedUnionSerializer {
348351
}
349352

350353
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
354+
let py = key.py();
351355
let mut new_extra = extra.clone();
352356
new_extra.check = SerCheck::Strict;
353357

354-
if let Some(tag) = self.get_discriminator_value(key) {
358+
if let Some(tag) = self.get_discriminator_value(key, extra) {
355359
let tag_str = tag.to_string();
356-
if let Some(serializer) = self.lookup.get(&tag_str) {
357-
match self.choices[*serializer].json_key(key, &new_extra) {
360+
if let Some(&serializer_index) = self.lookup.get(&tag_str) {
361+
let serializer = &self.choices[serializer_index];
362+
363+
match serializer.json_key(key, &new_extra) {
358364
Ok(v) => return Ok(v),
359-
Err(_) => {
360-
if self.retry_with_lax_check() {
361-
new_extra.check = SerCheck::Lax;
362-
return self.choices[*serializer].json_key(key, &new_extra);
365+
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(py) {
366+
true => {
367+
if self.retry_with_lax_check() {
368+
new_extra.check = SerCheck::Lax;
369+
return serializer.json_key(key, &new_extra);
370+
}
363371
}
364-
}
372+
false => return Err(err),
373+
},
365374
}
366375
}
367376
}
@@ -381,20 +390,25 @@ impl TypeSerializer for TaggedUnionSerializer {
381390
let mut new_extra = extra.clone();
382391
new_extra.check = SerCheck::Strict;
383392

384-
if let Some(tag) = self.get_discriminator_value(value) {
393+
if let Some(tag) = self.get_discriminator_value(value, extra) {
385394
let tag_str = tag.to_string();
386-
if let Some(selected_serializer) = self.lookup.get(&tag_str) {
387-
match self.choices[*selected_serializer].to_python(value, include, exclude, &new_extra) {
395+
if let Some(&serializer_index) = self.lookup.get(&tag_str) {
396+
let selected_serializer = &self.choices[serializer_index];
397+
398+
match selected_serializer.to_python(value, include, exclude, &new_extra) {
388399
Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra),
389-
Err(_) => {
390-
if self.retry_with_lax_check() {
391-
new_extra.check = SerCheck::Lax;
392-
match self.choices[*selected_serializer].to_python(value, include, exclude, &new_extra) {
393-
Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra),
394-
Err(err) => return Err(py_err_se_err(err)),
400+
Err(err) => match err.is_instance_of::<PydanticSerializationUnexpectedValue>(py) {
401+
true => {
402+
if self.retry_with_lax_check() {
403+
new_extra.check = SerCheck::Lax;
404+
match selected_serializer.to_python(value, include, exclude, &new_extra) {
405+
Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra),
406+
Err(err) => return Err(py_err_se_err(err)),
407+
}
395408
}
396409
}
397-
}
410+
false => return Err(py_err_se_err(err)),
411+
},
398412
}
399413
}
400414
}
@@ -421,7 +435,7 @@ impl TypeSerializer for TaggedUnionSerializer {
421435
}
422436

423437
impl TaggedUnionSerializer {
424-
fn get_discriminator_value(&self, value: &Bound<'_, PyAny>) -> Option<Py<PyAny>> {
438+
fn get_discriminator_value(&self, value: &Bound<'_, PyAny>, extra: &Extra) -> Option<Py<PyAny>> {
425439
let py = value.py();
426440
let discriminator_value = match &self.discriminator {
427441
Discriminator::LookupKey(lookup_key) => match lookup_key {
@@ -431,8 +445,12 @@ impl TaggedUnionSerializer {
431445
Discriminator::Function(func) => func.call1(py, (value,)).ok(),
432446
};
433447
if discriminator_value.is_none() {
434-
// warn if the discriminator value is not found
448+
extra.warnings.custom_warning(
449+
format!(
450+
"Failed to get discriminator value for tagged union serialization for {value} - defaulting to left to right union serialization."
451+
)
452+
);
435453
}
436-
return discriminator_value;
454+
discriminator_value
437455
}
438456
}

0 commit comments

Comments
 (0)