Skip to content

fix python GC traversal for validators and serializers #787

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ use pyo3::{prelude::*, sync::GILOnceCell};
#[global_allocator]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;

// parse this first to get access to the contained macro
#[macro_use]
mod py_gc;

mod argument_markers;
mod build_tools;
mod definitions;
Expand Down
70 changes: 70 additions & 0 deletions src/py_gc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use ahash::AHashMap;
use enum_dispatch::enum_dispatch;
use pyo3::{AsPyPointer, Py, PyTraverseError, PyVisit};

/// Trait implemented by types which can be traversed by the Python GC.
///
/// Implementations hand the Python references they hold to the supplied
/// visitor, either directly or by delegating to the `PyGcTraverse`
/// implementation of the values they contain.
#[enum_dispatch]
pub trait PyGcTraverse {
/// Traverses `self` on behalf of the Python GC using `visit`.
///
/// Returns `Ok(())` when traversal completes, or the `PyTraverseError`
/// produced during traversal so the caller can propagate it.
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError>;
}

/// A `Py<T>` is traversed by passing the reference itself to the visitor.
///
/// The impl is available for every `Py<T>` that satisfies `AsPyPointer`.
impl<T> PyGcTraverse for Py<T>
where
Py<T>: AsPyPointer,
{
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
// The result of `visit.call` is returned unchanged to the caller.
visit.call(self)
}
}

/// A vector is traversed by traversing each of its elements in order.
impl<T: PyGcTraverse> PyGcTraverse for Vec<T> {
    fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
        // `try_for_each` stops at, and returns, the first error an element reports.
        self.iter().try_for_each(|element| element.py_gc_traverse(visit))
    }
}

/// A string-keyed map is traversed by traversing each of its values; the keys
/// are not visited.
impl<T: PyGcTraverse> PyGcTraverse for AHashMap<String, T> {
    fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
        // `try_for_each` stops at, and returns, the first error a value reports.
        self.values().try_for_each(|value| value.py_gc_traverse(visit))
    }
}
Comment on lines +20 to +36
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like these might even make sense in PyO3 at some point, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, though probably not in this form. In PyO3 I'd like (if it's possible) to do it automatically with a #[derive] macro; however, the challenge is how to make most types no-ops and enable GC traversal for just the ones that matter. I'm sure there's an issue about this in the PyO3 backlog, but I can't find it right now.


/// A boxed value is traversed by traversing the value it points to.
impl<T: PyGcTraverse> PyGcTraverse for Box<T> {
    fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
        // Borrow the boxed value explicitly so the call goes to `T`'s implementation.
        let inner: &T = self;
        inner.py_gc_traverse(visit)
    }
}

/// An optional value is traversed only when it is present; `None` has nothing
/// to visit and always succeeds.
impl<T: PyGcTraverse> PyGcTraverse for Option<T> {
    fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
        if let Some(present) = self {
            present.py_gc_traverse(visit)
        } else {
            Ok(())
        }
    }
}

/// A crude alternative to a "derive" macro to help with building `PyGcTraverse` implementations.
///
/// Two forms are accepted:
///
/// * `impl_py_gc_traverse!(Type {});` generates an implementation that visits
///   nothing and returns `Ok(())`.
/// * `impl_py_gc_traverse!(Type { field_a, field_b });` generates an
///   implementation that calls `py_gc_traverse` on each listed field in the
///   order given, returning the first error encountered. A trailing comma after
///   the last field is allowed, so multi-line invocations can be written in the
///   usual style.
macro_rules! impl_py_gc_traverse {
    // No fields to visit: the visitor is unused, hence the `_visit` binding.
    ($name:ty { }) => {
        impl crate::py_gc::PyGcTraverse for $name {
            fn py_gc_traverse(&self, _visit: &pyo3::PyVisit<'_>) -> Result<(), pyo3::PyTraverseError> {
                Ok(())
            }
        }
    };
    // One or more fields, with an optional trailing comma.
    ($name:ty { $($fields:ident),+ $(,)? }) => {
        impl crate::py_gc::PyGcTraverse for $name {
            fn py_gc_traverse(&self, visit: &pyo3::PyVisit<'_>) -> Result<(), pyo3::PyTraverseError> {
                $(self.$fields.py_gc_traverse(visit)?;)+
                Ok(())
            }
        }
    };
}
13 changes: 12 additions & 1 deletion src/serializers/computed_fields.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyString};
use pyo3::{intern, PyTraverseError, PyVisit};
use serde::ser::SerializeMap;
use serde::Serialize;

use crate::build_tools::py_schema_error_type;
use crate::definitions::DefinitionsBuilder;
use crate::py_gc::PyGcTraverse;
use crate::serializers::filter::SchemaFilter;
use crate::serializers::shared::{BuildSerializer, CombinedSerializer, PydanticSerializer, TypeSerializer};
use crate::tools::SchemaDict;
Expand Down Expand Up @@ -156,6 +157,16 @@ pub(crate) struct ComputedFieldSerializer<'py> {
extra: &'py Extra<'py>,
}

impl_py_gc_traverse!(ComputedField { serializer });

/// GC traversal for `ComputedFields` is delegated entirely to the value stored
/// in its first tuple field.
impl PyGcTraverse for ComputedFields {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
// NOTE(review): `self.0` is presumably the collection of `ComputedField`
// entries — confirm against the struct definition.
self.0.py_gc_traverse(visit)
}
}

impl_py_gc_traverse!(ComputedFieldSerializer<'_> { computed_field });

impl<'py> Serialize for ComputedFieldSerializer<'py> {
fn serialize<S: serde::ser::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let py = self.model.py();
Expand Down
7 changes: 7 additions & 0 deletions src/serializers/fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ pub(super) struct SerField {
pub required: bool,
}

impl_py_gc_traverse!(SerField { serializer });

impl SerField {
pub fn new(
py: Python,
Expand Down Expand Up @@ -142,6 +144,11 @@ macro_rules! option_length {
};
}

impl_py_gc_traverse!(GeneralFieldsSerializer {
fields,
computed_fields
});

impl TypeSerializer for GeneralFieldsSerializer {
fn to_python(
&self,
Expand Down
8 changes: 1 addition & 7 deletions src/serializers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use pyo3::types::{PyBytes, PyDict};
use pyo3::{PyTraverseError, PyVisit};

use crate::definitions::DefinitionsBuilder;
use crate::py_gc::PyGcTraverse;
use crate::validators::SelfValidator;

use config::SerializationConfig;
Expand Down Expand Up @@ -191,13 +192,6 @@ impl SchemaSerializer {
}
Ok(())
}

fn __clear__(&mut self) {
self.serializer.py_gc_clear();
for slot in &mut self.definitions {
slot.py_gc_clear();
}
}
}

#[allow(clippy::too_many_arguments)]
Expand Down
47 changes: 43 additions & 4 deletions src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use serde_json::ser::PrettyFormatter;
use crate::build_tools::py_schema_err;
use crate::build_tools::py_schema_error_type;
use crate::definitions::DefinitionsBuilder;
use crate::py_gc::PyGcTraverse;
use crate::tools::{py_err, SchemaDict};

use super::errors::se_err_py_err;
Expand Down Expand Up @@ -215,12 +216,50 @@ impl BuildSerializer for CombinedSerializer {
}
}

// Implemented by hand because `enum_dispatch` fails with a proc macro compile error =/
impl PyGcTraverse for CombinedSerializer {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
match self {
CombinedSerializer::Function(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::FunctionWrap(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Fields(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::None(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Nullable(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Int(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Bool(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Float(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Str(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Bytes(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Datetime(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::TimeDelta(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Date(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Time(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::List(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Set(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::FrozenSet(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Generator(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Dict(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Model(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Dataclass(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Url(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::MultiHostUrl(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Any(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Format(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::ToString(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::WithDefault(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Json(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::JsonOrPython(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Union(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Literal(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Recursive(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::TuplePositional(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::TupleVariable(inner) => inner.py_gc_traverse(visit),
}
}
}

#[enum_dispatch(CombinedSerializer)]
pub(crate) trait TypeSerializer: Send + Sync + Clone + Debug {
fn py_gc_traverse(&self, _visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
Ok(())
}
fn py_gc_clear(&mut self) {}
fn to_python(
&self,
value: &PyAny,
Expand Down
2 changes: 2 additions & 0 deletions src/serializers/type_serializers/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ impl BuildSerializer for AnySerializer {
}
}

impl_py_gc_traverse!(AnySerializer {});

impl TypeSerializer for AnySerializer {
fn to_python(
&self,
Expand Down
2 changes: 2 additions & 0 deletions src/serializers/type_serializers/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ impl BuildSerializer for BytesSerializer {
}
}

impl_py_gc_traverse!(BytesSerializer {});

impl TypeSerializer for BytesSerializer {
fn to_python(
&self,
Expand Down
10 changes: 3 additions & 7 deletions src/serializers/type_serializers/dataclass.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyString, PyType};
use pyo3::{intern, PyTraverseError, PyVisit};
use std::borrow::Cow;

use ahash::AHashMap;
Expand Down Expand Up @@ -121,13 +121,9 @@ impl DataclassSerializer {
}
}

impl TypeSerializer for DataclassSerializer {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
visit.call(&self.class)?;
self.serializer.py_gc_traverse(visit)?;
Ok(())
}
impl_py_gc_traverse!(DataclassSerializer { class, serializer });

impl TypeSerializer for DataclassSerializer {
fn to_python(
&self,
value: &PyAny,
Expand Down
2 changes: 2 additions & 0 deletions src/serializers/type_serializers/datetime_etc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ macro_rules! build_serializer {
}
}

impl_py_gc_traverse!($struct_name {});

impl TypeSerializer for $struct_name {
fn to_python(
&self,
Expand Down
3 changes: 3 additions & 0 deletions src/serializers/type_serializers/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};

use crate::definitions::DefinitionsBuilder;

use crate::tools::SchemaDict;

use super::{py_err_se_err, BuildSerializer, CombinedSerializer, Extra, TypeSerializer};
Expand Down Expand Up @@ -59,6 +60,8 @@ impl BuildSerializer for DefinitionRefSerializer {
}
}

impl_py_gc_traverse!(DefinitionRefSerializer {});

impl TypeSerializer for DefinitionRefSerializer {
fn to_python(
&self,
Expand Down
5 changes: 5 additions & 0 deletions src/serializers/type_serializers/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ impl BuildSerializer for DictSerializer {
}
}

impl_py_gc_traverse!(DictSerializer {
key_serializer,
value_serializer
});

impl TypeSerializer for DictSerializer {
fn to_python(
&self,
Expand Down
4 changes: 4 additions & 0 deletions src/serializers/type_serializers/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ impl FormatSerializer {
}
}

impl_py_gc_traverse!(FormatSerializer { format_func });

impl TypeSerializer for FormatSerializer {
fn to_python(
&self,
Expand Down Expand Up @@ -175,6 +177,8 @@ impl BuildSerializer for ToStringSerializer {
}
}

impl_py_gc_traverse!(ToStringSerializer {});

impl TypeSerializer for ToStringSerializer {
fn to_python(
&self,
Expand Down
12 changes: 12 additions & 0 deletions src/serializers/type_serializers/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,12 @@ macro_rules! function_type_serializer {
};
}

impl_py_gc_traverse!(FunctionPlainSerializer {
func,
return_serializer,
fallback_serializer
});

function_type_serializer!(FunctionPlainSerializer);

fn copy_outer_schema(schema: &PyDict) -> PyResult<&PyDict> {
Expand Down Expand Up @@ -399,6 +405,12 @@ impl FunctionWrapSerializer {
}
}

impl_py_gc_traverse!(FunctionWrapSerializer {
serializer,
func,
return_serializer
});

function_type_serializer!(FunctionWrapSerializer);

#[pyclass(module = "pydantic_core._pydantic_core")]
Expand Down
2 changes: 2 additions & 0 deletions src/serializers/type_serializers/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ impl BuildSerializer for GeneratorSerializer {
}
}

impl_py_gc_traverse!(GeneratorSerializer { item_serializer });

impl TypeSerializer for GeneratorSerializer {
fn to_python(
&self,
Expand Down
2 changes: 2 additions & 0 deletions src/serializers/type_serializers/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ impl BuildSerializer for JsonSerializer {
}
}

impl_py_gc_traverse!(JsonSerializer { serializer });

impl TypeSerializer for JsonSerializer {
fn to_python(
&self,
Expand Down
2 changes: 2 additions & 0 deletions src/serializers/type_serializers/json_or_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ impl BuildSerializer for JsonOrPythonSerializer {
}
}

impl_py_gc_traverse!(JsonOrPythonSerializer { json, python });

impl TypeSerializer for JsonOrPythonSerializer {
fn to_python(
&self,
Expand Down
2 changes: 2 additions & 0 deletions src/serializers/type_serializers/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ impl BuildSerializer for ListSerializer {
}
}

impl_py_gc_traverse!(ListSerializer { item_serializer });

impl TypeSerializer for ListSerializer {
fn to_python(
&self,
Expand Down
2 changes: 2 additions & 0 deletions src/serializers/type_serializers/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ impl LiteralSerializer {
}
}

impl_py_gc_traverse!(LiteralSerializer { expected_py });

impl TypeSerializer for LiteralSerializer {
fn to_python(
&self,
Expand Down
Loading