Skip to content

fix segfaults on tests in debug builds for PyPy #744

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,22 @@ jobs:

# test with a debug build as it picks up errors which optimised release builds do not
test-debug:
name: test-debug ${{ matrix.python-version }}
runs-on: ubuntu-latest

strategy:
fail-fast: false
matrix:
python-version:
- '3.11'
- 'pypy3.9'

steps:
- uses: actions/checkout@v3
- name: set up python
uses: actions/setup-python@v4
with:
python-version: '3.11'
python-version: ${{ matrix.python-version }}

- run: pip install -r tests/requirements.txt
- run: pip install -e . --config-settings=build-args='--profile dev'
Expand Down
2 changes: 2 additions & 0 deletions python/pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ from _typeshed import SupportsAllComparisons
__all__ = [
'__version__',
'build_profile',
'_recursion_limit',
'ArgsKwargs',
'SchemaValidator',
'SchemaSerializer',
Expand All @@ -44,6 +45,7 @@ __all__ = [
]
__version__: str
build_profile: str
_recursion_limit: int

_T = TypeVar('_T', default=Any, covariant=True)

Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ pub fn get_version() -> String {
fn _pydantic_core(py: Python, m: &PyModule) -> PyResult<()> {
m.add("__version__", get_version())?;
m.add("build_profile", env!("PROFILE"))?;
m.add("_recursion_limit", recursion_guard::RECURSION_GUARD_LIMIT)?;
m.add("PydanticUndefined", PydanticUndefinedType::new(py))?;
m.add_class::<PydanticUndefinedType>()?;
m.add_class::<PySome>()?;
Expand Down
9 changes: 6 additions & 3 deletions src/recursion_guard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ type RecursionKey = (
/// Recursion-tracking state: a set of already-seen keys plus one shared depth counter.
/// With the derived `Default`, `ids` starts as `None` and `depth` starts at 0.
#[derive(Debug, Clone, Default)]
pub struct RecursionGuard {
// keys recorded by `contains_or_insert`, which reports whether a key was already present
ids: Option<AHashSet<RecursionKey>>,
// backup depth counter, bumped by `incr_depth`;
// see `BACKUP_GUARD_LIMIT` in src/validators/definitions.rs for details
// depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just
// use one number for all validators
depth: u16,
}

// Hard ceiling on recursion depth, so that rampant recursion is reported as an error
// instead of overflowing the stack. wasm targets get a smaller budget (50) than all
// other target families (255).
#[cfg(target_family = "wasm")]
pub const RECURSION_GUARD_LIMIT: u16 = 50;
#[cfg(not(target_family = "wasm"))]
pub const RECURSION_GUARD_LIMIT: u16 = 255;

impl RecursionGuard {
// insert a new id into the set, return whether the set already had the id in it
pub fn contains_or_insert(&mut self, obj_id: usize, node_id: usize) -> bool {
Expand All @@ -37,9 +39,10 @@ impl RecursionGuard {
}

// see #143: this is used as a backup in case the identity-check recursion guard fails
pub fn incr_depth(&mut self) -> u16 {
#[must_use]
pub fn incr_depth(&mut self) -> bool {
self.depth += 1;
self.depth
self.depth >= RECURSION_GUARD_LIMIT
}

pub fn decr_depth(&mut self) {
Expand Down
30 changes: 9 additions & 21 deletions src/serializers/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::{intern, AsPyPointer};

use ahash::AHashSet;
use serde::ser::Error;

use super::config::SerializationConfig;
use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER};
use super::ob_type::ObTypeLookup;
use super::shared::CombinedSerializer;
use crate::definitions::Definitions;
use crate::recursion_guard::RecursionGuard;

/// This is ugly; it would be much better if extra could be stored in `SerializationState`
/// and `SerializationState` then gained a `serialize_infer` method, but I couldn't get that to work.
Expand Down Expand Up @@ -347,43 +347,31 @@ impl CollectWarnings {
}
}

/// we have `RecursionInfo` then a `RefCell` since `SerializeInfer.serialize` can't take a `&mut self`
#[derive(Default, Clone)]
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct RecursionInfo {
ids: AHashSet<(usize, usize)>, // first element is the object's id, the second is the serializer's id
/// as with `src/recursion_guard.rs` this is used as a backup in case the identity check recursion guard fails
/// see #143
depth: u16,
}

#[derive(Default, Clone)]
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct SerRecursionGuard {
info: RefCell<RecursionInfo>,
guard: RefCell<RecursionGuard>,
}

impl SerRecursionGuard {
const MAX_DEPTH: u16 = 200;

pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<usize> {
// https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
// "If the set did not have this value present, `true` is returned."
let id = value.as_ptr() as usize;
let mut info = self.info.borrow_mut();
if !info.ids.insert((id, def_ref_id)) {
let mut guard = self.guard.borrow_mut();

if guard.contains_or_insert(id, def_ref_id) {
Err(PyValueError::new_err("Circular reference detected (id repeated)"))
} else if info.depth > Self::MAX_DEPTH {
} else if guard.incr_depth() {
Err(PyValueError::new_err("Circular reference detected (depth exceeded)"))
} else {
info.depth += 1;
Ok(id)
}
}

pub fn pop(&self, id: usize, def_ref_id: usize) {
let mut info = self.info.borrow_mut();
info.depth -= 1;
info.ids.remove(&(id, def_ref_id));
let mut guard = self.guard.borrow_mut();
guard.decr_depth();
guard.remove(id, def_ref_id);
}
}
13 changes: 2 additions & 11 deletions src/validators/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl Validator for DefinitionRefValidator {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorType::RecursionLoop, input))
} else {
if recursion_guard.incr_depth() > BACKUP_GUARD_LIMIT {
if recursion_guard.incr_depth() {
return Err(ValError::new(ErrorType::RecursionLoop, input));
}
let output = validate(self.validator_id, py, input, extra, definitions, recursion_guard);
Expand Down Expand Up @@ -112,7 +112,7 @@ impl Validator for DefinitionRefValidator {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorType::RecursionLoop, obj))
} else {
if recursion_guard.incr_depth() > BACKUP_GUARD_LIMIT {
if recursion_guard.incr_depth() {
return Err(ValError::new(ErrorType::RecursionLoop, obj));
}
let output = validate_assignment(
Expand Down Expand Up @@ -169,15 +169,6 @@ impl Validator for DefinitionRefValidator {
}
}

// Backup for the identity-check recursion guard (see #143): once a validator's "depth"
// (how many times it has been entered inside itself) goes past this limit, a recursion
// error is raised. PyPy and wasm builds use the lower value (123); everything else uses 255.
const BACKUP_GUARD_LIMIT: u16 = if cfg!(not(any(PyPy, target_family = "wasm"))) {
    255
} else {
    123
};

fn validate<'s, 'data>(
validator_id: usize,
py: Python<'data>,
Expand Down
17 changes: 15 additions & 2 deletions tests/benchmarks/test_micro_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import os
import platform
import sys
from datetime import date, datetime, timedelta, timezone
from decimal import Decimal
from enum import Enum
Expand All @@ -12,6 +13,7 @@
import pytest
from dirty_equals import IsStr

import pydantic_core
from pydantic_core import ArgsKwargs, PydanticCustomError, SchemaValidator, ValidationError, core_schema
from pydantic_core import ValidationError as CoreValidationError

Expand All @@ -26,6 +28,15 @@

skip_pydantic = pytest.mark.skipif(BaseModel is None, reason='skipping benchmarks vs. pydantic')

# Skip marker for deeply recursive benchmarks: active only when the interpreter is PyPy
# AND the compiled extension reports `build_profile == 'debug'`.
skip_pypy_deep_stack = pytest.mark.skipif(
platform.python_implementation() == 'PyPy' and pydantic_core._pydantic_core.build_profile == 'debug',
reason='PyPy does not have enough stack space for Rust debug builds to recurse very deep',
)

# Skip marker for deeply recursive benchmarks when `sys.platform` is 'emscripten' (wasm).
skip_wasm_deep_stack = pytest.mark.skipif(
sys.platform == 'emscripten', reason='wasm does not have enough stack space to recurse very deep'
)


class TestBenchmarkSimpleModel:
@pytest.fixture(scope='class')
Expand Down Expand Up @@ -328,7 +339,6 @@ def definition_model_data():
return data


@pytest.mark.skipif(platform.python_implementation() == 'PyPy', reason='crashes on pypy due to recursion depth')
@skip_pydantic
@pytest.mark.benchmark(group='recursive model')
def test_definition_model_pyd(definition_model_data, benchmark):
Expand All @@ -339,7 +349,8 @@ class PydanticBranch(BaseModel):
benchmark(PydanticBranch.parse_obj, definition_model_data)


@pytest.mark.skipif(platform.python_implementation() == 'PyPy', reason='crashes on pypy due to recursion depth')
@skip_pypy_deep_stack
@skip_wasm_deep_stack
@pytest.mark.benchmark(group='recursive model')
def test_definition_model_core(definition_model_data, benchmark):
class CoreBranch:
Expand Down Expand Up @@ -1452,6 +1463,8 @@ def test_tagged_union_int_keys_json(benchmark):
benchmark(v.validate_json, payload)


@skip_pypy_deep_stack
@skip_wasm_deep_stack
@pytest.mark.benchmark(group='field_function_validator')
def test_field_function_validator(benchmark) -> None:
def f(v: int, info: core_schema.FieldValidationInfo) -> int:
Expand Down
11 changes: 9 additions & 2 deletions tests/serializers/test_any.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dataclasses
import json
import platform
import sys
from collections import namedtuple
from datetime import date, datetime, time, timedelta, timezone
Expand All @@ -11,6 +12,7 @@
import pytest
from dirty_equals import HasRepr, IsList

import pydantic_core
from pydantic_core import PydanticSerializationError, SchemaSerializer, SchemaValidator, core_schema, to_json

from ..conftest import plain_repr
Expand Down Expand Up @@ -343,6 +345,10 @@ def __repr__(self):
return f'<FoobarCount {self.v} repr>'


@pytest.mark.skipif(
platform.python_implementation() == 'PyPy' and pydantic_core._pydantic_core.build_profile == 'debug',
reason='PyPy does not have enough stack space for Rust debug builds to recurse very deep',
)
def test_fallback_cycle_change(any_serializer: SchemaSerializer):
v = 1

Expand All @@ -359,8 +365,9 @@ def fallback_func(obj):

f = FoobarCount(0)
v = 0
# because when recursion is detected and we're in mode python, we just return the value
assert any_serializer.to_python(f, fallback=fallback_func) == HasRepr('<FoobarCount 201 repr>')
# when recursion is detected and we're in mode python, we just return the value
expected_visits = pydantic_core._pydantic_core._recursion_limit - 1
assert any_serializer.to_python(f, fallback=fallback_func) == HasRepr(f'<FoobarCount {expected_visits} repr>')

with pytest.raises(ValueError, match=r'Circular reference detected \(depth exceeded\)'):
any_serializer.to_json(f, fallback=fallback_func)
Expand Down
6 changes: 6 additions & 0 deletions tests/test_json.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import json
import platform
import re
from typing import List

import pytest
from dirty_equals import IsList

import pydantic_core
from pydantic_core import (
PydanticSerializationError,
SchemaSerializer,
Expand Down Expand Up @@ -285,6 +287,10 @@ def fallback_func_passthrough(obj):
to_json(f, fallback=fallback_func_passthrough)


@pytest.mark.skipif(
platform.python_implementation() == 'PyPy' and pydantic_core._pydantic_core.build_profile == 'debug',
reason='PyPy does not have enough stack space for Rust debug builds to recurse very deep',
)
def test_cycle_change():
def fallback_func_change_id(obj):
return Foobar()
Expand Down
8 changes: 7 additions & 1 deletion tests/validators/test_definitions_recursive.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import platform
from dataclasses import dataclass
from typing import List, Optional

import pytest
from dirty_equals import AnyThing, HasAttributes, IsList, IsPartialDict, IsStr, IsTuple

import pydantic_core
from pydantic_core import SchemaError, SchemaValidator, ValidationError, __version__, core_schema

from ..conftest import Err, plain_repr
Expand Down Expand Up @@ -714,6 +716,10 @@ def f(input_value, info):
]


@pytest.mark.skipif(
platform.python_implementation() == 'PyPy' and pydantic_core._pydantic_core.build_profile == 'debug',
reason='PyPy does not have enough stack space for Rust debug builds to recurse very deep',
)
@pytest.mark.parametrize('strict', [True, False], ids=lambda s: f'strict={s}')
def test_function_change_id(strict: bool):
def f(input_value, info):
Expand Down Expand Up @@ -750,7 +756,7 @@ def f(input_value, info):


def test_many_uses_of_ref():
# check we can safely exceed BACKUP_GUARD_LIMIT without upsetting the backup recursion guard
# check we can safely exceed RECURSION_GUARD_LIMIT without upsetting the recursion guard
v = SchemaValidator(
{
'type': 'typed-dict',
Expand Down