Skip to content

Commit a99729a

Browse files
authored
fix memory leak with iterable validation (#1271)
1 parent f537a03 commit a99729a

File tree

3 files changed

+60
-2
lines changed

3 files changed

+60
-2
lines changed

src/input/return_enums.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use serde::{ser::Error, Serialize, Serializer};
1919
use crate::errors::{
2020
py_err_string, ErrorType, ErrorTypeDefaults, InputValue, ToErrorValue, ValError, ValLineError, ValResult,
2121
};
22+
use crate::py_gc::PyGcTraverse;
2223
use crate::tools::{extract_i64, new_py_string, py_err};
2324
use crate::validators::{CombinedValidator, Exactness, ValidationState, Validator};
2425

@@ -327,6 +328,15 @@ pub enum GenericIterator<'data> {
327328
JsonArray(GenericJsonIterator<'data>),
328329
}
329330

331+
impl PyGcTraverse for GenericIterator<'_> {
332+
fn py_gc_traverse(&self, visit: &pyo3::PyVisit<'_>) -> Result<(), pyo3::PyTraverseError> {
333+
if let Self::PyIterator(iter) = self {
334+
iter.py_gc_traverse(visit)?;
335+
}
336+
Ok(())
337+
}
338+
}
339+
330340
impl GenericIterator<'_> {
331341
pub(crate) fn into_static(self) -> GenericIterator<'static> {
332342
match self {
@@ -385,6 +395,8 @@ impl GenericPyIterator {
385395
}
386396
}
387397

398+
impl_py_gc_traverse!(GenericPyIterator { obj, iter });
399+
388400
#[derive(Debug, Clone)]
389401
pub struct GenericJsonIterator<'data> {
390402
array: JsonArray<'data>,

src/validators/generator.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
use std::fmt;
22
use std::sync::Arc;
33

4-
use pyo3::prelude::*;
54
use pyo3::types::PyDict;
5+
use pyo3::{prelude::*, PyTraverseError, PyVisit};
66

77
use crate::errors::{ErrorType, LocItem, ValError, ValResult};
88
use crate::input::{BorrowInput, GenericIterator, Input};
9+
use crate::py_gc::PyGcTraverse;
910
use crate::recursion_guard::RecursionState;
1011
use crate::tools::SchemaDict;
1112
use crate::ValidationError;
@@ -201,6 +202,12 @@ impl ValidatorIterator {
201202
fn __str__(&self) -> String {
202203
self.__repr__()
203204
}
205+
206+
fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
207+
self.iterator.py_gc_traverse(&visit)?;
208+
self.validator.py_gc_traverse(&visit)?;
209+
Ok(())
210+
}
204211
}
205212

206213
/// Owned validator wrapper for use in generators in functions, this can be passed back to python

tests/test_garbage_collection.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import gc
22
import platform
3-
from typing import Any
3+
from typing import Any, Iterable
44
from weakref import WeakValueDictionary
55

66
import pytest
@@ -79,3 +79,42 @@ class MyModel(BaseModel):
7979
gc.collect(2)
8080

8181
assert len(cache) == 0
82+
83+
84+
@pytest.mark.xfail(
85+
condition=platform.python_implementation() == 'PyPy', reason='https://foss.heptapod.net/pypy/pypy/-/issues/3899'
86+
)
87+
def test_gc_validator_iterator() -> None:
88+
# test for https://github.com/pydantic/pydantic/issues/9243
89+
class MyModel:
90+
iter: Iterable[int]
91+
92+
v = SchemaValidator(
93+
core_schema.model_schema(
94+
MyModel,
95+
core_schema.model_fields_schema(
96+
{'iter': core_schema.model_field(core_schema.generator_schema(core_schema.int_schema()))}
97+
),
98+
),
99+
)
100+
101+
class MyIterable:
102+
def __iter__(self):
103+
return self
104+
105+
def __next__(self):
106+
raise StopIteration()
107+
108+
cache: 'WeakValueDictionary[int, Any]' = WeakValueDictionary()
109+
110+
for _ in range(10_000):
111+
iterable = MyIterable()
112+
cache[id(iterable)] = iterable
113+
v.validate_python({'iter': iterable})
114+
del iterable
115+
116+
gc.collect(0)
117+
gc.collect(1)
118+
gc.collect(2)
119+
120+
assert len(cache) == 0

0 commit comments

Comments
 (0)