Skip to content

Commit 97c286e

Browse files
committed
fix segfaults on tests in debug builds for PyPy
1 parent d08d269 commit 97c286e

File tree

8 files changed

+51
-38
lines changed

8 files changed

+51
-38
lines changed

.github/workflows/ci.yml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,14 +142,22 @@ jobs:
142142

143143
# test with a debug build as it picks up errors which optimised release builds do not
144144
test-debug:
145+
name: test-debug ${{ matrix.python-version }}
145146
runs-on: ubuntu-latest
146147

148+
strategy:
149+
fail-fast: false
150+
matrix:
151+
python-version:
152+
- '3.11'
153+
- 'pypy3.9'
154+
147155
steps:
148156
- uses: actions/checkout@v3
149157
- name: set up python
150158
uses: actions/setup-python@v4
151159
with:
152-
python-version: '3.11'
160+
python-version: ${{ matrix.python-version }}
153161

154162
- run: pip install -r tests/requirements.txt
155163
- run: pip install -e . --config-settings=build-args='--profile dev'

src/recursion_guard.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@ type RecursionKey = (
1414
#[derive(Debug, Clone, Default)]
1515
pub struct RecursionGuard {
1616
ids: Option<AHashSet<RecursionKey>>,
17-
// see validators/definition::BACKUP_GUARD_LIMIT for details
1817
// depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just
1918
// use one number for all validators
2019
depth: u16,
2120
}
2221

22+
// A hard limit to avoid stack overflows when rampant recursion occurs
23+
const RECURSION_GUARD_LIMIT: u16 = if cfg!(target_family = "wasm") { 50 } else { 255 };
24+
2325
impl RecursionGuard {
2426
// insert a new id into the set, return whether the set already had the id in it
2527
pub fn contains_or_insert(&mut self, obj_id: usize, node_id: usize) -> bool {
@@ -37,9 +39,10 @@ impl RecursionGuard {
3739
}
3840

3941
// see #143 this is used as a backup in case the identity check recursion guard fails
40-
pub fn incr_depth(&mut self) -> u16 {
42+
#[must_use]
43+
pub fn incr_depth(&mut self) -> bool {
4144
self.depth += 1;
42-
self.depth
45+
self.depth >= RECURSION_GUARD_LIMIT
4346
}
4447

4548
pub fn decr_depth(&mut self) {

src/serializers/extra.rs

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@ use pyo3::exceptions::PyValueError;
55
use pyo3::prelude::*;
66
use pyo3::{intern, AsPyPointer};
77

8-
use ahash::AHashSet;
98
use serde::ser::Error;
109

1110
use super::config::SerializationConfig;
1211
use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER};
1312
use super::ob_type::ObTypeLookup;
1413
use super::shared::CombinedSerializer;
1514
use crate::definitions::Definitions;
15+
use crate::recursion_guard::RecursionGuard;
1616

1717
/// this is ugly, would be much better if extra could be stored in `SerializationState`
1818
/// then `SerializationState` got a `serialize_infer` method, but I couldn't get it to work
@@ -347,43 +347,31 @@ impl CollectWarnings {
347347
}
348348
}
349349

350-
/// we have `RecursionInfo` then a `RefCell` since `SerializeInfer.serialize` can't take a `&mut self`
351-
#[derive(Default, Clone)]
352-
#[cfg_attr(debug_assertions, derive(Debug))]
353-
pub struct RecursionInfo {
354-
ids: AHashSet<(usize, usize)>, // first element is the object's id, the second is the serializer's id
355-
/// as with `src/recursion_guard.rs` this is used as a backup in case the identity check recursion guard fails
356-
/// see #143
357-
depth: u16,
358-
}
359-
360350
#[derive(Default, Clone)]
361351
#[cfg_attr(debug_assertions, derive(Debug))]
362352
pub struct SerRecursionGuard {
363-
info: RefCell<RecursionInfo>,
353+
guard: RefCell<RecursionGuard>,
364354
}
365355

366356
impl SerRecursionGuard {
367-
const MAX_DEPTH: u16 = 200;
368-
369357
pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<usize> {
370358
// https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
371359
// "If the set did not have this value present, `true` is returned."
372360
let id = value.as_ptr() as usize;
373-
let mut info = self.info.borrow_mut();
374-
if !info.ids.insert((id, def_ref_id)) {
361+
let mut guard = self.guard.borrow_mut();
362+
363+
if guard.contains_or_insert(id, def_ref_id) {
375364
Err(PyValueError::new_err("Circular reference detected (id repeated)"))
376-
} else if info.depth > Self::MAX_DEPTH {
365+
} else if guard.incr_depth() {
377366
Err(PyValueError::new_err("Circular reference detected (depth exceeded)"))
378367
} else {
379-
info.depth += 1;
380368
Ok(id)
381369
}
382370
}
383371

384372
pub fn pop(&self, id: usize, def_ref_id: usize) {
385-
let mut info = self.info.borrow_mut();
386-
info.depth -= 1;
387-
info.ids.remove(&(id, def_ref_id));
373+
let mut guard = self.guard.borrow_mut();
374+
guard.decr_depth();
375+
guard.remove(id, def_ref_id);
388376
}
389377
}

src/validators/definitions.rs

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ impl Validator for DefinitionRefValidator {
8484
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
8585
Err(ValError::new(ErrorType::RecursionLoop, input))
8686
} else {
87-
if recursion_guard.incr_depth() > BACKUP_GUARD_LIMIT {
87+
if recursion_guard.incr_depth() {
8888
return Err(ValError::new(ErrorType::RecursionLoop, input));
8989
}
9090
let output = validate(self.validator_id, py, input, extra, definitions, recursion_guard);
@@ -112,7 +112,7 @@ impl Validator for DefinitionRefValidator {
112112
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
113113
Err(ValError::new(ErrorType::RecursionLoop, obj))
114114
} else {
115-
if recursion_guard.incr_depth() > BACKUP_GUARD_LIMIT {
115+
if recursion_guard.incr_depth() {
116116
return Err(ValError::new(ErrorType::RecursionLoop, obj));
117117
}
118118
let output = validate_assignment(
@@ -169,15 +169,6 @@ impl Validator for DefinitionRefValidator {
169169
}
170170
}
171171

172-
// see #143 this is a backup in case the identity check recursion guard fails
173-
// if a single validator "depth" (how many times it's called inside itself) exceeds the limit,
174-
// we raise a recursion error.
175-
const BACKUP_GUARD_LIMIT: u16 = if cfg!(PyPy) || cfg!(target_family = "wasm") {
176-
123
177-
} else {
178-
255
179-
};
180-
181172
fn validate<'s, 'data>(
182173
validator_id: usize,
183174
py: Python<'data>,

tests/benchmarks/test_micro_benchmarks.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import pytest
1313
from dirty_equals import IsStr
1414

15+
import pydantic_core
1516
from pydantic_core import ArgsKwargs, PydanticCustomError, SchemaValidator, ValidationError, core_schema
1617
from pydantic_core import ValidationError as CoreValidationError
1718

@@ -1453,6 +1454,10 @@ def test_tagged_union_int_keys_json(benchmark):
14531454

14541455

14551456
@pytest.mark.benchmark(group='field_function_validator')
1457+
@pytest.mark.skipif(
1458+
platform.python_implementation() == 'PyPy' and pydantic_core._pydantic_core.build_profile == 'debug',
1459+
reason='PyPy does not have enough stack space for Rust debug builds to recurse very deep',
1460+
)
14561461
def test_field_function_validator(benchmark) -> None:
14571462
def f(v: int, info: core_schema.FieldValidationInfo) -> int:
14581463
assert info.field_name == 'x'

tests/serializers/test_any.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import dataclasses
22
import json
3+
import platform
34
import sys
45
from collections import namedtuple
56
from datetime import date, datetime, time, timedelta, timezone
@@ -11,6 +12,7 @@
1112
import pytest
1213
from dirty_equals import HasRepr, IsList
1314

15+
import pydantic_core
1416
from pydantic_core import PydanticSerializationError, SchemaSerializer, SchemaValidator, core_schema, to_json
1517

1618
from ..conftest import plain_repr
@@ -343,6 +345,10 @@ def __repr__(self):
343345
return f'<FoobarCount {self.v} repr>'
344346

345347

348+
@pytest.mark.skipif(
349+
platform.python_implementation() == 'PyPy' and pydantic_core._pydantic_core.build_profile == 'debug',
350+
reason='PyPy does not have enough stack space for Rust debug builds to recurse very deep',
351+
)
346352
def test_fallback_cycle_change(any_serializer: SchemaSerializer):
347353
v = 1
348354

@@ -360,7 +366,7 @@ def fallback_func(obj):
360366
f = FoobarCount(0)
361367
v = 0
362368
# because when recursion is detected and we're in mode python, we just return the value
363-
assert any_serializer.to_python(f, fallback=fallback_func) == HasRepr('<FoobarCount 201 repr>')
369+
assert any_serializer.to_python(f, fallback=fallback_func) == HasRepr('<FoobarCount 254 repr>')
364370

365371
with pytest.raises(ValueError, match=r'Circular reference detected \(depth exceeded\)'):
366372
any_serializer.to_json(f, fallback=fallback_func)

tests/test_json.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import json
2+
import platform
23
import re
34
from typing import List
45

56
import pytest
67
from dirty_equals import IsList
78

9+
import pydantic_core
810
from pydantic_core import (
911
PydanticSerializationError,
1012
SchemaSerializer,
@@ -285,6 +287,10 @@ def fallback_func_passthrough(obj):
285287
to_json(f, fallback=fallback_func_passthrough)
286288

287289

290+
@pytest.mark.skipif(
291+
platform.python_implementation() == 'PyPy' and pydantic_core._pydantic_core.build_profile == 'debug',
292+
reason='PyPy does not have enough stack space for Rust debug builds to recurse very deep',
293+
)
288294
def test_cycle_change():
289295
def fallback_func_change_id(obj):
290296
return Foobar()

tests/validators/test_definitions_recursive.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import platform
12
from dataclasses import dataclass
23
from typing import List, Optional
34

45
import pytest
56
from dirty_equals import AnyThing, HasAttributes, IsList, IsPartialDict, IsStr, IsTuple
67

8+
import pydantic_core
79
from pydantic_core import SchemaError, SchemaValidator, ValidationError, __version__, core_schema
810

911
from ..conftest import Err, plain_repr
@@ -714,6 +716,10 @@ def f(input_value, info):
714716
]
715717

716718

719+
@pytest.mark.skipif(
720+
platform.python_implementation() == 'PyPy' and pydantic_core._pydantic_core.build_profile == 'debug',
721+
reason='PyPy does not have enough stack space for Rust debug builds to recurse very deep',
722+
)
717723
@pytest.mark.parametrize('strict', [True, False], ids=lambda s: f'strict={s}')
718724
def test_function_change_id(strict: bool):
719725
def f(input_value, info):
@@ -750,7 +756,7 @@ def f(input_value, info):
750756

751757

752758
def test_many_uses_of_ref():
753-
# check we can safely exceed BACKUP_GUARD_LIMIT without upsetting the backup recursion guard
759+
# check we can safely exceed RECURSION_GUARD_LIMIT without upsetting the recursion guard
754760
v = SchemaValidator(
755761
{
756762
'type': 'typed-dict',

0 commit comments

Comments
 (0)