Skip to content

Commit fdd1e85

Browse files
Adding tagged union serializer 🚀 (#1397)
1 parent 3d8295e commit fdd1e85

File tree

12 files changed

+430
-193
lines changed

12 files changed

+430
-193
lines changed

src/common/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pub(crate) mod union;

src/common/union.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
use pyo3::prelude::*;
2+
use pyo3::{PyTraverseError, PyVisit};
3+
4+
use crate::lookup_key::LookupKey;
5+
use crate::py_gc::PyGcTraverse;
6+
7+
#[derive(Debug, Clone)]
8+
pub enum Discriminator {
9+
/// use `LookupKey` to find the tag, same as we do to find values in typed_dict aliases
10+
LookupKey(LookupKey),
11+
/// call a function to find the tag to use
12+
Function(PyObject),
13+
}
14+
15+
impl Discriminator {
16+
pub fn new(py: Python, raw: &Bound<'_, PyAny>) -> PyResult<Self> {
17+
if raw.is_callable() {
18+
return Ok(Self::Function(raw.to_object(py)));
19+
}
20+
21+
let lookup_key = LookupKey::from_py(py, raw, None)?;
22+
Ok(Self::LookupKey(lookup_key))
23+
}
24+
25+
pub fn to_string_py(&self, py: Python) -> PyResult<String> {
26+
match self {
27+
Self::Function(f) => Ok(format!("{}()", f.getattr(py, "__name__")?)),
28+
Self::LookupKey(lookup_key) => Ok(lookup_key.to_string()),
29+
}
30+
}
31+
}
32+
33+
impl PyGcTraverse for Discriminator {
34+
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
35+
match self {
36+
Self::Function(obj) => visit.call(obj)?,
37+
Self::LookupKey(_) => {}
38+
}
39+
Ok(())
40+
}
41+
}
42+
43+
pub(crate) const SMALL_UNION_THRESHOLD: usize = 4;

src/input/input_abstract.rs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,6 @@ pub trait ValidatedDict<'py> {
239239
where
240240
Self: 'a;
241241
fn get_item<'k>(&self, key: &'k LookupKey) -> ValResult<Option<(&'k LookupPath, Self::Item<'_>)>>;
242-
fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>>;
243242
// FIXME this is a bit of a leaky abstraction
244243
fn is_py_get_attr(&self) -> bool {
245244
false
@@ -282,9 +281,6 @@ impl<'py> ValidatedDict<'py> for Never {
282281
fn get_item<'k>(&self, _key: &'k LookupKey) -> ValResult<Option<(&'k LookupPath, Self::Item<'_>)>> {
283282
unreachable!()
284283
}
285-
fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>> {
286-
unreachable!()
287-
}
288284
fn iterate<'a, R>(
289285
&'a self,
290286
_consumer: impl ConsumeIterator<ValResult<(Self::Key<'a>, Self::Item<'a>)>, Output = R>,

src/input/input_json.rs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -509,10 +509,6 @@ impl<'py, 'data> ValidatedDict<'py> for &'_ JsonObject<'data> {
509509
key.json_get(self)
510510
}
511511

512-
fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>> {
513-
None
514-
}
515-
516512
fn iterate<'a, R>(
517513
&'a self,
518514
consumer: impl ConsumeIterator<ValResult<(Self::Key<'a>, Self::Item<'a>)>, Output = R>,

src/input/input_python.rs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -817,13 +817,6 @@ impl<'py> ValidatedDict<'py> for GenericPyMapping<'_, 'py> {
817817
matches!(self, Self::GetAttr(..))
818818
}
819819

820-
fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>> {
821-
match self {
822-
Self::Dict(dict) => Some(dict),
823-
_ => None,
824-
}
825-
}
826-
827820
fn iterate<'a, R>(
828821
&'a self,
829822
consumer: impl ConsumeIterator<ValResult<(Self::Key<'a>, Self::Item<'a>)>, Output = R>,

src/input/input_string.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -293,9 +293,6 @@ impl<'py> ValidatedDict<'py> for StringMappingDict<'py> {
293293
fn get_item<'k>(&self, key: &'k LookupKey) -> ValResult<Option<(&'k LookupPath, Self::Item<'_>)>> {
294294
key.py_get_string_mapping_item(&self.0)
295295
}
296-
fn as_py_dict(&self) -> Option<&Bound<'py, PyDict>> {
297-
None
298-
}
299296
fn iterate<'a, R>(
300297
&'a self,
301298
consumer: impl super::ConsumeIterator<ValResult<(Self::Key<'a>, Self::Item<'a>)>, Output = R>,

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ mod py_gc;
1616

1717
mod argument_markers;
1818
mod build_tools;
19+
mod common;
1920
mod definitions;
2021
mod errors;
2122
mod input;

src/serializers/shared.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ combined_serializer! {
8888
// `find_only` is for type_serializers which are built directly via the `type` key and `find_serializer`
8989
// but aren't actually used for serialization, e.g. their `build` method must return another serializer
9090
find_only: {
91-
super::type_serializers::union::TaggedUnionBuilder;
9291
super::type_serializers::other::ChainBuilder;
9392
super::type_serializers::other::CustomErrorBuilder;
9493
super::type_serializers::other::CallBuilder;
@@ -138,6 +137,7 @@ combined_serializer! {
138137
Json: super::type_serializers::json::JsonSerializer;
139138
JsonOrPython: super::type_serializers::json_or_python::JsonOrPythonSerializer;
140139
Union: super::type_serializers::union::UnionSerializer;
140+
TaggedUnion: super::type_serializers::union::TaggedUnionSerializer;
141141
Literal: super::type_serializers::literal::LiteralSerializer;
142142
Enum: super::type_serializers::enum_::EnumSerializer;
143143
Recursive: super::type_serializers::definitions::DefinitionRefSerializer;
@@ -247,6 +247,7 @@ impl PyGcTraverse for CombinedSerializer {
247247
CombinedSerializer::Json(inner) => inner.py_gc_traverse(visit),
248248
CombinedSerializer::JsonOrPython(inner) => inner.py_gc_traverse(visit),
249249
CombinedSerializer::Union(inner) => inner.py_gc_traverse(visit),
250+
CombinedSerializer::TaggedUnion(inner) => inner.py_gc_traverse(visit),
250251
CombinedSerializer::Literal(inner) => inner.py_gc_traverse(visit),
251252
CombinedSerializer::Enum(inner) => inner.py_gc_traverse(visit),
252253
CombinedSerializer::Recursive(inner) => inner.py_gc_traverse(visit),

0 commit comments

Comments
 (0)