Skip to content

Commit a47f5ef

Browse files
committed
ensure recursion guard is always used as a stack
1 parent 4da7192 commit a47f5ef

File tree

9 files changed

+166
-104
lines changed

9 files changed

+166
-104
lines changed

src/recursion_guard.rs

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,72 @@ type RecursionKey = (
1212

1313
/// This is used to avoid cyclic references in input data causing recursive validation and a nasty segmentation fault.
1414
/// It's used in `validators/definition` to detect when a reference is reused within itself.
15+
pub(crate) struct RecursionGuard<'a, S: ContainsRecursionState> {
16+
state: &'a mut S,
17+
obj_id: usize,
18+
node_id: usize,
19+
incr_depth: bool,
20+
}
21+
22+
pub(crate) enum RecursionError {
23+
/// Cyclic reference detected
24+
Cyclic,
25+
/// Recursion limit exceeded
26+
Depth,
27+
}
28+
29+
impl<S: ContainsRecursionState> RecursionGuard<'_, S> {
30+
/// Creates a recursion guard for the given object and node id.
31+
///
32+
/// When dropped, this will release the recursion for the given object and node id.
33+
pub fn new(
34+
state: &'_ mut S,
35+
obj_id: usize,
36+
node_id: usize,
37+
incr_depth: bool,
38+
) -> Result<RecursionGuard<'_, S>, RecursionError> {
39+
state.access_recursion_state(|state| {
40+
if !state.insert(obj_id, node_id) {
41+
return Err(RecursionError::Cyclic);
42+
}
43+
if incr_depth && state.incr_depth() {
44+
return Err(RecursionError::Depth);
45+
}
46+
Ok(())
47+
})?;
48+
Ok(RecursionGuard {
49+
state,
50+
obj_id,
51+
node_id,
52+
incr_depth,
53+
})
54+
}
55+
56+
/// Retrieves the underlying state for further use.
57+
pub fn state(&mut self) -> &mut S {
58+
self.state
59+
}
60+
}
61+
62+
impl<S: ContainsRecursionState> Drop for RecursionGuard<'_, S> {
63+
fn drop(&mut self) {
64+
self.state.access_recursion_state(|state| {
65+
state.remove(self.obj_id, self.node_id);
66+
if self.incr_depth {
67+
state.decr_depth();
68+
}
69+
});
70+
}
71+
}
72+
73+
/// This trait is used to retrieve the recursion state from some other type
74+
pub(crate) trait ContainsRecursionState {
75+
fn access_recursion_state<R>(&mut self, f: impl FnOnce(&mut RecursionState) -> R) -> R;
76+
}
77+
78+
/// State for the RecursionGuard. Can also be used directly to increase / decrease depth.
1579
#[derive(Debug, Clone, Default)]
16-
pub struct RecursionGuard {
80+
pub struct RecursionState {
1781
ids: RecursionStack,
1882
// depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just
1983
// use one number for all validators
@@ -31,11 +95,11 @@ pub const RECURSION_GUARD_LIMIT: u8 = if cfg!(any(target_family = "wasm", all(wi
3195
255
3296
};
3397

34-
impl RecursionGuard {
98+
impl RecursionState {
3599
// insert a new value
36100
// * return `false` if the stack already had it in it
37101
// * return `true` if the stack didn't have it in it and it was inserted
38-
pub fn insert(&mut self, obj_id: usize, node_id: usize) -> bool {
102+
fn insert(&mut self, obj_id: usize, node_id: usize) -> bool {
39103
self.ids.insert((obj_id, node_id))
40104
}
41105

@@ -68,7 +132,7 @@ impl RecursionGuard {
68132
self.depth = self.depth.saturating_sub(1);
69133
}
70134

71-
pub fn remove(&mut self, obj_id: usize, node_id: usize) {
135+
fn remove(&mut self, obj_id: usize, node_id: usize) {
72136
self.ids.remove(&(obj_id, node_id));
73137
}
74138
}
@@ -98,7 +162,7 @@ impl RecursionStack {
98162
// insert a new value
99163
// * return `false` if the stack already had it in it
100164
// * return `true` if the stack didn't have it in it and it was inserted
101-
pub fn insert(&mut self, v: RecursionKey) -> bool {
165+
fn insert(&mut self, v: RecursionKey) -> bool {
102166
match self {
103167
Self::Array { data, len } => {
104168
if *len < ARRAY_SIZE {
@@ -129,7 +193,7 @@ impl RecursionStack {
129193
}
130194
}
131195

132-
pub fn remove(&mut self, v: &RecursionKey) {
196+
fn remove(&mut self, v: &RecursionKey) {
133197
match self {
134198
Self::Array { data, len } => {
135199
*len = len.checked_sub(1).expect("remove from empty recursion guard");

src/serializers/extra.rs

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,23 @@ use serde::ser::Error;
1010
use super::config::SerializationConfig;
1111
use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER};
1212
use super::ob_type::ObTypeLookup;
13+
use crate::recursion_guard::ContainsRecursionState;
14+
use crate::recursion_guard::RecursionError;
1315
use crate::recursion_guard::RecursionGuard;
16+
use crate::recursion_guard::RecursionState;
1417

1518
/// this is ugly, would be much better if extra could be stored in `SerializationState`
1619
/// then `SerializationState` got a `serialize_infer` method, but I couldn't get it to work
1720
pub(crate) struct SerializationState {
1821
warnings: CollectWarnings,
19-
rec_guard: SerRecursionGuard,
22+
rec_guard: SerRecursionState,
2023
config: SerializationConfig,
2124
}
2225

2326
impl SerializationState {
2427
pub fn new(timedelta_mode: &str, bytes_mode: &str, inf_nan_mode: &str) -> PyResult<Self> {
2528
let warnings = CollectWarnings::new(false);
26-
let rec_guard = SerRecursionGuard::default();
29+
let rec_guard = SerRecursionState::default();
2730
let config = SerializationConfig::from_args(timedelta_mode, bytes_mode, inf_nan_mode)?;
2831
Ok(Self {
2932
warnings,
@@ -77,7 +80,7 @@ pub(crate) struct Extra<'a> {
7780
pub exclude_none: bool,
7881
pub round_trip: bool,
7982
pub config: &'a SerializationConfig,
80-
pub rec_guard: &'a SerRecursionGuard,
83+
pub rec_guard: &'a SerRecursionState,
8184
// the next two are used for union logic
8285
pub check: SerCheck,
8386
// data representing the current model field
@@ -101,7 +104,7 @@ impl<'a> Extra<'a> {
101104
exclude_none: bool,
102105
round_trip: bool,
103106
config: &'a SerializationConfig,
104-
rec_guard: &'a SerRecursionGuard,
107+
rec_guard: &'a SerRecursionState,
105108
serialize_unknown: bool,
106109
fallback: Option<&'a PyAny>,
107110
) -> Self {
@@ -124,6 +127,18 @@ impl<'a> Extra<'a> {
124127
}
125128
}
126129

130+
pub fn recursion_guard<'x, 'y>(
131+
self: &'x mut &'y Self,
132+
value: &PyAny,
133+
def_ref_id: usize,
134+
inc_depth: bool,
135+
) -> PyResult<RecursionGuard<'x, &'y Self>> {
136+
RecursionGuard::new(self, value.as_ptr() as usize, def_ref_id, inc_depth).map_err(|e| match e {
137+
RecursionError::Depth => PyValueError::new_err("Circular reference detected (depth exceeded)"),
138+
RecursionError::Cyclic => PyValueError::new_err("Circular reference detected (id repeated)"),
139+
})
140+
}
141+
127142
pub fn serialize_infer<'py>(&'py self, value: &'py PyAny) -> super::infer::SerializeInfer<'py> {
128143
super::infer::SerializeInfer::new(value, None, None, self)
129144
}
@@ -157,7 +172,7 @@ pub(crate) struct ExtraOwned {
157172
exclude_none: bool,
158173
round_trip: bool,
159174
config: SerializationConfig,
160-
rec_guard: SerRecursionGuard,
175+
rec_guard: SerRecursionState,
161176
check: SerCheck,
162177
model: Option<PyObject>,
163178
field_name: Option<String>,
@@ -340,29 +355,35 @@ impl CollectWarnings {
340355

341356
#[derive(Default, Clone)]
342357
#[cfg_attr(debug_assertions, derive(Debug))]
343-
pub struct SerRecursionGuard {
344-
guard: RefCell<RecursionGuard>,
358+
pub struct SerRecursionState {
359+
guard: RefCell<RecursionState>,
345360
}
346361

347-
impl SerRecursionGuard {
348-
pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<usize> {
349-
let id = value.as_ptr() as usize;
350-
let mut guard = self.guard.borrow_mut();
351-
352-
if guard.insert(id, def_ref_id) {
353-
if guard.incr_depth() {
354-
Err(PyValueError::new_err("Circular reference detected (depth exceeded)"))
355-
} else {
356-
Ok(id)
357-
}
358-
} else {
359-
Err(PyValueError::new_err("Circular reference detected (id repeated)"))
360-
}
361-
}
362-
363-
pub fn pop(&self, id: usize, def_ref_id: usize) {
364-
let mut guard = self.guard.borrow_mut();
365-
guard.decr_depth();
366-
guard.remove(id, def_ref_id);
362+
impl ContainsRecursionState for &'_ Extra<'_> {
363+
fn access_recursion_state<R>(&mut self, f: impl FnOnce(&mut RecursionState) -> R) -> R {
364+
f(&mut self.rec_guard.guard.borrow_mut())
367365
}
368366
}
367+
368+
// impl SerRecursionState {
369+
// pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<usize> {
370+
// let id = value.as_ptr() as usize;
371+
// let mut guard = self.guard.borrow_mut();
372+
373+
// if guard.insert(id, def_ref_id) {
374+
// if guard.incr_depth() {
375+
// Err(PyValueError::new_err("Circular reference detected (depth exceeded)"))
376+
// } else {
377+
// Ok(id)
378+
// }
379+
// } else {
380+
// Err(PyValueError::new_err("Circular reference detected (id repeated)"))
381+
// }
382+
// }
383+
384+
// pub fn pop(&self, id: usize, def_ref_id: usize) {
385+
// let mut guard = self.guard.borrow_mut();
386+
// guard.decr_depth();
387+
// guard.remove(id, def_ref_id);
388+
// }
389+
// }

src/serializers/infer.rs

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,12 @@ pub(crate) fn infer_to_python_known(
4040
value: &PyAny,
4141
include: Option<&PyAny>,
4242
exclude: Option<&PyAny>,
43-
extra: &Extra,
43+
mut extra: &Extra,
4444
) -> PyResult<PyObject> {
4545
let py = value.py();
46-
let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID) {
47-
Ok(id) => id,
48-
Err(e) => {
49-
return match extra.mode {
50-
SerMode::Json => Err(e),
51-
// if recursion is detected by we're serializing to python, we just return the value
52-
_ => Ok(value.into_py(py)),
53-
};
54-
}
55-
};
46+
47+
let mut guard = extra.recursion_guard(value, INFER_DEF_REF_ID, true)?;
48+
let extra = guard.state();
5649

5750
macro_rules! serialize_seq {
5851
($t:ty) => {
@@ -220,7 +213,6 @@ pub(crate) fn infer_to_python_known(
220213
if let Some(fallback) = extra.fallback {
221214
let next_value = fallback.call1((value,))?;
222215
let next_result = infer_to_python(next_value, include, exclude, extra);
223-
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
224216
return next_result;
225217
} else if extra.serialize_unknown {
226218
serialize_unknown(value).into_py(py)
@@ -267,15 +259,13 @@ pub(crate) fn infer_to_python_known(
267259
if let Some(fallback) = extra.fallback {
268260
let next_value = fallback.call1((value,))?;
269261
let next_result = infer_to_python(next_value, include, exclude, extra);
270-
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
271262
return next_result;
272263
}
273264
value.into_py(py)
274265
}
275266
_ => value.into_py(py),
276267
},
277268
};
278-
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
279269
Ok(value)
280270
}
281271

@@ -332,18 +322,21 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
332322
serializer: S,
333323
include: Option<&PyAny>,
334324
exclude: Option<&PyAny>,
335-
extra: &Extra,
325+
mut extra: &Extra,
336326
) -> Result<S::Ok, S::Error> {
337-
let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID).map_err(py_err_se_err) {
327+
let extra_serialize_unknown = extra.serialize_unknown;
328+
let mut guard = match extra.recursion_guard(value, INFER_DEF_REF_ID, true) {
338329
Ok(v) => v,
339330
Err(e) => {
340-
return if extra.serialize_unknown {
331+
return if extra_serialize_unknown {
341332
serializer.serialize_str("...")
342333
} else {
343-
Err(e)
344-
}
334+
Err(py_err_se_err(e))
335+
};
345336
}
346337
};
338+
let extra = guard.state();
339+
347340
macro_rules! serialize {
348341
($t:ty) => {
349342
match value.extract::<$t>() {
@@ -506,7 +499,6 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
506499
if let Some(fallback) = extra.fallback {
507500
let next_value = fallback.call1((value,)).map_err(py_err_se_err)?;
508501
let next_result = infer_serialize(next_value, serializer, include, exclude, extra);
509-
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
510502
return next_result;
511503
} else if extra.serialize_unknown {
512504
serializer.serialize_str(&serialize_unknown(value))
@@ -520,7 +512,6 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
520512
}
521513
}
522514
};
523-
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
524515
ser_result
525516
}
526517

src/serializers/mod.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::py_gc::PyGcTraverse;
1010

1111
use config::SerializationConfig;
1212
pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue};
13-
use extra::{CollectWarnings, SerRecursionGuard};
13+
use extra::{CollectWarnings, SerRecursionState};
1414
pub(crate) use extra::{Extra, SerMode, SerializationState};
1515
pub use shared::CombinedSerializer;
1616
use shared::{to_json_bytes, BuildSerializer, TypeSerializer};
@@ -52,7 +52,7 @@ impl SchemaSerializer {
5252
exclude_defaults: bool,
5353
exclude_none: bool,
5454
round_trip: bool,
55-
rec_guard: &'a SerRecursionGuard,
55+
rec_guard: &'a SerRecursionState,
5656
serialize_unknown: bool,
5757
fallback: Option<&'a PyAny>,
5858
) -> Extra<'b> {
@@ -113,7 +113,7 @@ impl SchemaSerializer {
113113
) -> PyResult<PyObject> {
114114
let mode: SerMode = mode.into();
115115
let warnings = CollectWarnings::new(warnings);
116-
let rec_guard = SerRecursionGuard::default();
116+
let rec_guard = SerRecursionState::default();
117117
let extra = self.build_extra(
118118
py,
119119
&mode,
@@ -152,7 +152,7 @@ impl SchemaSerializer {
152152
fallback: Option<&PyAny>,
153153
) -> PyResult<PyObject> {
154154
let warnings = CollectWarnings::new(warnings);
155-
let rec_guard = SerRecursionGuard::default();
155+
let rec_guard = SerRecursionState::default();
156156
let extra = self.build_extra(
157157
py,
158158
&SerMode::Json,

0 commit comments

Comments
 (0)