Skip to content

ensure recursion guard is always used as a stack #1166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 57 additions & 6 deletions src/recursion_guard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,59 @@ type RecursionKey = (

/// This is used to avoid cyclic references in input data causing recursive validation and a nasty segmentation fault.
/// It's used in `validators/definition` to detect when a reference is reused within itself.
pub(crate) struct RecursionGuard<'a, S: ContainsRecursionState> {
state: &'a mut S,
obj_id: usize,
node_id: usize,
}

pub(crate) enum RecursionError {
/// Cyclic reference detected
Cyclic,
/// Recursion limit exceeded
Depth,
}

impl<S: ContainsRecursionState> RecursionGuard<'_, S> {
/// Creates a recursion guard for the given object and node id.
///
/// When dropped, this will release the recursion for the given object and node id.
pub fn new(state: &'_ mut S, obj_id: usize, node_id: usize) -> Result<RecursionGuard<'_, S>, RecursionError> {
state.access_recursion_state(|state| {
if !state.insert(obj_id, node_id) {
return Err(RecursionError::Cyclic);
}
if state.incr_depth() {
return Err(RecursionError::Depth);
}
Ok(())
})?;
Ok(RecursionGuard { state, obj_id, node_id })
}

/// Retrieves the underlying state for further use.
pub fn state(&mut self) -> &mut S {
self.state
}
}

impl<S: ContainsRecursionState> Drop for RecursionGuard<'_, S> {
fn drop(&mut self) {
self.state.access_recursion_state(|state| {
state.decr_depth();
state.remove(self.obj_id, self.node_id);
});
}
}

/// This trait is used to retrieve the recursion state from some other type
pub(crate) trait ContainsRecursionState {
fn access_recursion_state<R>(&mut self, f: impl FnOnce(&mut RecursionState) -> R) -> R;
}

/// State for the RecursionGuard. Can also be used directly to increase / decrease depth.
#[derive(Debug, Clone, Default)]
pub struct RecursionGuard {
pub struct RecursionState {
ids: RecursionStack,
// depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just
// use one number for all validators
Expand All @@ -31,11 +82,11 @@ pub const RECURSION_GUARD_LIMIT: u8 = if cfg!(any(target_family = "wasm", all(wi
255
};

impl RecursionGuard {
impl RecursionState {
// insert a new value
// * return `false` if the stack already had it in it
// * return `true` if the stack didn't have it in it and it was inserted
pub fn insert(&mut self, obj_id: usize, node_id: usize) -> bool {
fn insert(&mut self, obj_id: usize, node_id: usize) -> bool {
self.ids.insert((obj_id, node_id))
}

Expand Down Expand Up @@ -68,7 +119,7 @@ impl RecursionGuard {
self.depth = self.depth.saturating_sub(1);
}

pub fn remove(&mut self, obj_id: usize, node_id: usize) {
fn remove(&mut self, obj_id: usize, node_id: usize) {
self.ids.remove(&(obj_id, node_id));
}
}
Expand Down Expand Up @@ -98,7 +149,7 @@ impl RecursionStack {
// insert a new value
// * return `false` if the stack already had it in it
// * return `true` if the stack didn't have it in it and it was inserted
pub fn insert(&mut self, v: RecursionKey) -> bool {
fn insert(&mut self, v: RecursionKey) -> bool {
match self {
Self::Array { data, len } => {
if *len < ARRAY_SIZE {
Expand Down Expand Up @@ -129,7 +180,7 @@ impl RecursionStack {
}
}

pub fn remove(&mut self, v: &RecursionKey) {
fn remove(&mut self, v: &RecursionKey) {
match self {
Self::Array { data, len } => {
*len = len.checked_sub(1).expect("remove from empty recursion guard");
Expand Down
56 changes: 29 additions & 27 deletions src/serializers/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,23 @@ use serde::ser::Error;
use super::config::SerializationConfig;
use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER};
use super::ob_type::ObTypeLookup;
use crate::recursion_guard::ContainsRecursionState;
use crate::recursion_guard::RecursionError;
use crate::recursion_guard::RecursionGuard;
use crate::recursion_guard::RecursionState;

/// this is ugly, would be much better if extra could be stored in `SerializationState`
/// then `SerializationState` got a `serialize_infer` method, but I couldn't get it to work
pub(crate) struct SerializationState {
warnings: CollectWarnings,
rec_guard: SerRecursionGuard,
rec_guard: SerRecursionState,
config: SerializationConfig,
}

impl SerializationState {
pub fn new(timedelta_mode: &str, bytes_mode: &str, inf_nan_mode: &str) -> PyResult<Self> {
let warnings = CollectWarnings::new(false);
let rec_guard = SerRecursionGuard::default();
let rec_guard = SerRecursionState::default();
let config = SerializationConfig::from_args(timedelta_mode, bytes_mode, inf_nan_mode)?;
Ok(Self {
warnings,
Expand Down Expand Up @@ -77,7 +80,7 @@ pub(crate) struct Extra<'a> {
pub exclude_none: bool,
pub round_trip: bool,
pub config: &'a SerializationConfig,
pub rec_guard: &'a SerRecursionGuard,
pub rec_guard: &'a SerRecursionState,
// the next two are used for union logic
pub check: SerCheck,
// data representing the current model field
Expand All @@ -101,7 +104,7 @@ impl<'a> Extra<'a> {
exclude_none: bool,
round_trip: bool,
config: &'a SerializationConfig,
rec_guard: &'a SerRecursionGuard,
rec_guard: &'a SerRecursionState,
serialize_unknown: bool,
fallback: Option<&'a PyAny>,
) -> Self {
Expand All @@ -124,6 +127,22 @@ impl<'a> Extra<'a> {
}
}

pub fn recursion_guard<'x, 'y>(
// TODO: this double reference is a bit if a hack, but it's necessary because the recursion
// guard is not passed around with &mut reference
//
// See how validation has &mut ValidationState passed around; we should aim to refactor
// to match that.
self: &'x mut &'y Self,
value: &PyAny,
def_ref_id: usize,
) -> PyResult<RecursionGuard<'x, &'y Self>> {
RecursionGuard::new(self, value.as_ptr() as usize, def_ref_id).map_err(|e| match e {
RecursionError::Depth => PyValueError::new_err("Circular reference detected (depth exceeded)"),
RecursionError::Cyclic => PyValueError::new_err("Circular reference detected (id repeated)"),
})
}

pub fn serialize_infer<'py>(&'py self, value: &'py PyAny) -> super::infer::SerializeInfer<'py> {
super::infer::SerializeInfer::new(value, None, None, self)
}
Expand Down Expand Up @@ -157,7 +176,7 @@ pub(crate) struct ExtraOwned {
exclude_none: bool,
round_trip: bool,
config: SerializationConfig,
rec_guard: SerRecursionGuard,
rec_guard: SerRecursionState,
check: SerCheck,
model: Option<PyObject>,
field_name: Option<String>,
Expand Down Expand Up @@ -340,29 +359,12 @@ impl CollectWarnings {

#[derive(Default, Clone)]
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct SerRecursionGuard {
guard: RefCell<RecursionGuard>,
pub struct SerRecursionState {
guard: RefCell<RecursionState>,
}

impl SerRecursionGuard {
pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<usize> {
let id = value.as_ptr() as usize;
let mut guard = self.guard.borrow_mut();

if guard.insert(id, def_ref_id) {
if guard.incr_depth() {
Err(PyValueError::new_err("Circular reference detected (depth exceeded)"))
} else {
Ok(id)
}
} else {
Err(PyValueError::new_err("Circular reference detected (id repeated)"))
}
}

pub fn pop(&self, id: usize, def_ref_id: usize) {
let mut guard = self.guard.borrow_mut();
guard.decr_depth();
guard.remove(id, def_ref_id);
impl ContainsRecursionState for &'_ Extra<'_> {
fn access_recursion_state<R>(&mut self, f: impl FnOnce(&mut RecursionState) -> R) -> R {
f(&mut self.rec_guard.guard.borrow_mut())
}
}
29 changes: 15 additions & 14 deletions src/serializers/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,22 @@ pub(crate) fn infer_to_python_known(
value: &PyAny,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
mut extra: &Extra,
) -> PyResult<PyObject> {
let py = value.py();
let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID) {
Ok(id) => id,

let mode = extra.mode;
let mut guard = match extra.recursion_guard(value, INFER_DEF_REF_ID) {
Ok(v) => v,
Err(e) => {
return match extra.mode {
return match mode {
SerMode::Json => Err(e),
// if recursion is detected by we're serializing to python, we just return the value
_ => Ok(value.into_py(py)),
};
}
};
let extra = guard.state();

macro_rules! serialize_seq {
($t:ty) => {
Expand Down Expand Up @@ -220,7 +223,6 @@ pub(crate) fn infer_to_python_known(
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,))?;
let next_result = infer_to_python(next_value, include, exclude, extra);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
return next_result;
} else if extra.serialize_unknown {
serialize_unknown(value).into_py(py)
Expand Down Expand Up @@ -267,15 +269,13 @@ pub(crate) fn infer_to_python_known(
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,))?;
let next_result = infer_to_python(next_value, include, exclude, extra);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
return next_result;
}
value.into_py(py)
}
_ => value.into_py(py),
},
};
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
Ok(value)
}

Expand Down Expand Up @@ -332,18 +332,21 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
serializer: S,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
mut extra: &Extra,
) -> Result<S::Ok, S::Error> {
let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID).map_err(py_err_se_err) {
let extra_serialize_unknown = extra.serialize_unknown;
let mut guard = match extra.recursion_guard(value, INFER_DEF_REF_ID) {
Ok(v) => v,
Err(e) => {
return if extra.serialize_unknown {
return if extra_serialize_unknown {
serializer.serialize_str("...")
} else {
Err(e)
}
Err(py_err_se_err(e))
};
}
};
let extra = guard.state();

macro_rules! serialize {
($t:ty) => {
match value.extract::<$t>() {
Expand Down Expand Up @@ -506,7 +509,6 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,)).map_err(py_err_se_err)?;
let next_result = infer_serialize(next_value, serializer, include, exclude, extra);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
return next_result;
} else if extra.serialize_unknown {
serializer.serialize_str(&serialize_unknown(value))
Expand All @@ -520,7 +522,6 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
}
}
};
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
ser_result
}

Expand Down
8 changes: 4 additions & 4 deletions src/serializers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::py_gc::PyGcTraverse;

use config::SerializationConfig;
pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue};
use extra::{CollectWarnings, SerRecursionGuard};
use extra::{CollectWarnings, SerRecursionState};
pub(crate) use extra::{Extra, SerMode, SerializationState};
pub use shared::CombinedSerializer;
use shared::{to_json_bytes, BuildSerializer, TypeSerializer};
Expand Down Expand Up @@ -52,7 +52,7 @@ impl SchemaSerializer {
exclude_defaults: bool,
exclude_none: bool,
round_trip: bool,
rec_guard: &'a SerRecursionGuard,
rec_guard: &'a SerRecursionState,
serialize_unknown: bool,
fallback: Option<&'a PyAny>,
) -> Extra<'b> {
Expand Down Expand Up @@ -113,7 +113,7 @@ impl SchemaSerializer {
) -> PyResult<PyObject> {
let mode: SerMode = mode.into();
let warnings = CollectWarnings::new(warnings);
let rec_guard = SerRecursionGuard::default();
let rec_guard = SerRecursionState::default();
let extra = self.build_extra(
py,
&mode,
Expand Down Expand Up @@ -152,7 +152,7 @@ impl SchemaSerializer {
fallback: Option<&PyAny>,
) -> PyResult<PyObject> {
let warnings = CollectWarnings::new(warnings);
let rec_guard = SerRecursionGuard::default();
let rec_guard = SerRecursionState::default();
let extra = self.build_extra(
py,
&SerMode::Json,
Expand Down
19 changes: 7 additions & 12 deletions src/serializers/type_serializers/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,12 @@ impl TypeSerializer for DefinitionRefSerializer {
value: &PyAny,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
mut extra: &Extra,
) -> PyResult<PyObject> {
self.definition.read(|comb_serializer| {
let comb_serializer = comb_serializer.unwrap();
let value_id = extra.rec_guard.add(value, self.definition.id())?;
let r = comb_serializer.to_python(value, include, exclude, extra);
extra.rec_guard.pop(value_id, self.definition.id());
r
let mut guard = extra.recursion_guard(value, self.definition.id())?;
comb_serializer.to_python(value, include, exclude, guard.state())
})
}

Expand All @@ -87,17 +85,14 @@ impl TypeSerializer for DefinitionRefSerializer {
serializer: S,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
mut extra: &Extra,
) -> Result<S::Ok, S::Error> {
self.definition.read(|comb_serializer| {
let comb_serializer = comb_serializer.unwrap();
let value_id = extra
.rec_guard
.add(value, self.definition.id())
let mut guard = extra
.recursion_guard(value, self.definition.id())
.map_err(py_err_se_err)?;
let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra);
extra.rec_guard.pop(value_id, self.definition.id());
r
comb_serializer.serde_serialize(value, serializer, include, exclude, guard.state())
})
}

Expand Down
Loading