Skip to content

Commit f2a0bb8

Browse files
Fix tagged union serialization warning when using aliases (#1442)
1 parent c462f77 commit f2a0bb8

File tree

3 files changed

+87
-30
lines changed

3 files changed

+87
-30
lines changed

src/lookup_key.rs

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -191,34 +191,10 @@ impl LookupKey {
191191
}
192192
}
193193

194-
pub fn py_get_attr<'py, 's>(
194+
pub fn simple_py_get_attr<'py, 's>(
195195
&'s self,
196196
obj: &Bound<'py, PyAny>,
197-
kwargs: Option<&Bound<'py, PyDict>>,
198-
) -> ValResult<Option<(&'s LookupPath, Bound<'py, PyAny>)>> {
199-
match self._py_get_attr(obj, kwargs) {
200-
Ok(v) => Ok(v),
201-
Err(err) => {
202-
let error = py_err_string(obj.py(), err);
203-
Err(ValError::new(
204-
ErrorType::GetAttributeError { error, context: None },
205-
obj,
206-
))
207-
}
208-
}
209-
}
210-
211-
pub fn _py_get_attr<'py, 's>(
212-
&'s self,
213-
obj: &Bound<'py, PyAny>,
214-
kwargs: Option<&Bound<'py, PyDict>>,
215197
) -> PyResult<Option<(&'s LookupPath, Bound<'py, PyAny>)>> {
216-
if let Some(dict) = kwargs {
217-
if let Ok(Some(item)) = self.py_get_dict_item(dict) {
218-
return Ok(Some(item));
219-
}
220-
}
221-
222198
match self {
223199
Self::Simple { py_key, path, .. } => match py_get_attrs(obj, py_key)? {
224200
Some(value) => Ok(Some((path, value))),
@@ -260,6 +236,29 @@ impl LookupKey {
260236
}
261237
}
262238

239+
pub fn py_get_attr<'py, 's>(
240+
&'s self,
241+
obj: &Bound<'py, PyAny>,
242+
kwargs: Option<&Bound<'py, PyDict>>,
243+
) -> ValResult<Option<(&'s LookupPath, Bound<'py, PyAny>)>> {
244+
if let Some(dict) = kwargs {
245+
if let Ok(Some(item)) = self.py_get_dict_item(dict) {
246+
return Ok(Some(item));
247+
}
248+
}
249+
250+
match self.simple_py_get_attr(obj) {
251+
Ok(v) => Ok(v),
252+
Err(err) => {
253+
let error = py_err_string(obj.py(), err);
254+
Err(ValError::new(
255+
ErrorType::GetAttributeError { error, context: None },
256+
obj,
257+
))
258+
}
259+
}
260+
}
261+
263262
pub fn json_get<'a, 'data, 's>(
264263
&'s self,
265264
dict: &'a JsonObject<'data>,

src/serializers/type_serializers/union.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ use std::borrow::Cow;
88
use crate::build_tools::py_schema_err;
99
use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD};
1010
use crate::definitions::DefinitionsBuilder;
11-
use crate::lookup_key::LookupKey;
1211
use crate::serializers::type_serializers::py_err_se_err;
1312
use crate::tools::{truncate_safe_repr, SchemaDict};
1413
use crate::PydanticSerializationUnexpectedValue;
@@ -438,10 +437,10 @@ impl TaggedUnionSerializer {
438437
fn get_discriminator_value(&self, value: &Bound<'_, PyAny>, extra: &Extra) -> Option<Py<PyAny>> {
439438
let py = value.py();
440439
let discriminator_value = match &self.discriminator {
441-
Discriminator::LookupKey(lookup_key) => match lookup_key {
442-
LookupKey::Simple { py_key, .. } => value.getattr(py_key).ok().map(|obj| obj.to_object(py)),
443-
_ => None,
444-
},
440+
Discriminator::LookupKey(lookup_key) => lookup_key
441+
.simple_py_get_attr(value)
442+
.ok()
443+
.and_then(|opt| opt.map(|(_, bound)| bound.to_object(py))),
445444
Discriminator::Function(func) => func.call1(py, (value,)).ok(),
446445
};
447446
if discriminator_value.is_none() {

tests/serializers/test_union.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,3 +711,62 @@ def test_custom_serializer() -> None:
711711
print(s)
712712
assert s.to_python([{'id': 1}, {'id': 2}]) == [1, 2]
713713
assert s.to_python({'id': 1}) == 1
714+
715+
716+
def test_tagged_union_with_aliases() -> None:
717+
@dataclasses.dataclass
718+
class ModelA:
719+
field: int
720+
tag: Literal['a'] = 'a'
721+
722+
@dataclasses.dataclass
723+
class ModelB:
724+
field: int
725+
tag: Literal['b'] = 'b'
726+
727+
s = SchemaSerializer(
728+
core_schema.tagged_union_schema(
729+
choices={
730+
'a': core_schema.dataclass_schema(
731+
ModelA,
732+
core_schema.dataclass_args_schema(
733+
'ModelA',
734+
[
735+
core_schema.dataclass_field(name='field', schema=core_schema.int_schema()),
736+
core_schema.dataclass_field(
737+
name='tag',
738+
schema=core_schema.literal_schema(['a']),
739+
validation_alias='TAG',
740+
serialization_alias='TAG',
741+
),
742+
],
743+
),
744+
['field', 'tag'],
745+
),
746+
'b': core_schema.dataclass_schema(
747+
ModelB,
748+
core_schema.dataclass_args_schema(
749+
'ModelB',
750+
[
751+
core_schema.dataclass_field(name='field', schema=core_schema.int_schema()),
752+
core_schema.dataclass_field(
753+
name='tag',
754+
schema=core_schema.literal_schema(['b']),
755+
validation_alias='TAG',
756+
serialization_alias='TAG',
757+
),
758+
],
759+
),
760+
['field', 'tag'],
761+
),
762+
},
763+
discriminator=[['tag'], ['TAG']],
764+
)
765+
)
766+
767+
assert 'TaggedUnionSerializer' in repr(s)
768+
769+
model_a = ModelA(field=1)
770+
model_b = ModelB(field=1)
771+
assert s.to_python(model_a) == {'field': 1, 'TAG': 'a'}
772+
assert s.to_python(model_b) == {'field': 1, 'TAG': 'b'}

0 commit comments

Comments
 (0)