Skip to content

Commit eef45dc

Browse files
committed
ensure recursion guard is always used as a stack
1 parent 4da7192 commit eef45dc

File tree

9 files changed

+174
-97
lines changed

9 files changed

+174
-97
lines changed

src/recursion_guard.rs

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,72 @@ type RecursionKey = (
1212

1313
/// This is used to avoid cyclic references in input data causing recursive validation and a nasty segmentation fault.
1414
/// It's used in `validators/definition` to detect when a reference is reused within itself.
15+
pub(crate) struct RecursionGuard<'a, S: ContainsRecursionState> {
16+
state: &'a mut S,
17+
obj_id: usize,
18+
node_id: usize,
19+
incr_depth: bool,
20+
}
21+
22+
pub(crate) enum RecursionError {
23+
/// Cyclic reference detected
24+
Cyclic,
25+
/// Recursion limit exceeded
26+
Depth,
27+
}
28+
29+
impl<S: ContainsRecursionState> RecursionGuard<'_, S> {
30+
/// Creates a recursion guard for the given object and node id.
31+
///
32+
/// When dropped, this will release the recursion for the given object and node id.
33+
pub fn new(
34+
state: &'_ mut S,
35+
obj_id: usize,
36+
node_id: usize,
37+
incr_depth: bool,
38+
) -> Result<RecursionGuard<'_, S>, RecursionError> {
39+
state.access_recursion_state(|state| {
40+
if !state.insert(obj_id, node_id) {
41+
return Err(RecursionError::Cyclic);
42+
}
43+
if incr_depth && state.incr_depth() {
44+
return Err(RecursionError::Depth);
45+
}
46+
Ok(())
47+
})?;
48+
Ok(RecursionGuard {
49+
state,
50+
obj_id,
51+
node_id,
52+
incr_depth,
53+
})
54+
}
55+
56+
/// Retrieves the underlying state for further use.
57+
pub fn state(&mut self) -> &mut S {
58+
self.state
59+
}
60+
}
61+
62+
impl<S: ContainsRecursionState> Drop for RecursionGuard<'_, S> {
63+
fn drop(&mut self) {
64+
self.state.access_recursion_state(|state| {
65+
state.remove(self.obj_id, self.node_id);
66+
if self.incr_depth {
67+
state.decr_depth();
68+
}
69+
});
70+
}
71+
}
72+
73+
/// This trait is used to retrieve the recursion state from some other type
74+
pub(crate) trait ContainsRecursionState {
75+
fn access_recursion_state<R>(&mut self, f: impl FnOnce(&mut RecursionState) -> R) -> R;
76+
}
77+
78+
/// State for the RecursionGuard. Can also be used directly to increase / decrease depth.
1579
#[derive(Debug, Clone, Default)]
16-
pub struct RecursionGuard {
80+
pub struct RecursionState {
1781
ids: RecursionStack,
1882
// depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just
1983
// use one number for all validators
@@ -31,11 +95,11 @@ pub const RECURSION_GUARD_LIMIT: u8 = if cfg!(any(target_family = "wasm", all(wi
3195
255
3296
};
3397

34-
impl RecursionGuard {
98+
impl RecursionState {
3599
// insert a new value
36100
// * return `false` if the stack already had it in it
37101
// * return `true` if the stack didn't have it in it and it was inserted
38-
pub fn insert(&mut self, obj_id: usize, node_id: usize) -> bool {
102+
fn insert(&mut self, obj_id: usize, node_id: usize) -> bool {
39103
self.ids.insert((obj_id, node_id))
40104
}
41105

@@ -68,7 +132,7 @@ impl RecursionGuard {
68132
self.depth = self.depth.saturating_sub(1);
69133
}
70134

71-
pub fn remove(&mut self, obj_id: usize, node_id: usize) {
135+
fn remove(&mut self, obj_id: usize, node_id: usize) {
72136
self.ids.remove(&(obj_id, node_id));
73137
}
74138
}
@@ -98,7 +162,7 @@ impl RecursionStack {
98162
// insert a new value
99163
// * return `false` if the stack already had it in it
100164
// * return `true` if the stack didn't have it in it and it was inserted
101-
pub fn insert(&mut self, v: RecursionKey) -> bool {
165+
fn insert(&mut self, v: RecursionKey) -> bool {
102166
match self {
103167
Self::Array { data, len } => {
104168
if *len < ARRAY_SIZE {
@@ -129,7 +193,7 @@ impl RecursionStack {
129193
}
130194
}
131195

132-
pub fn remove(&mut self, v: &RecursionKey) {
196+
fn remove(&mut self, v: &RecursionKey) {
133197
match self {
134198
Self::Array { data, len } => {
135199
*len = len.checked_sub(1).expect("remove from empty recursion guard");

src/serializers/extra.rs

Lines changed: 53 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,23 @@ use serde::ser::Error;
1010
use super::config::SerializationConfig;
1111
use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER};
1212
use super::ob_type::ObTypeLookup;
13+
use crate::recursion_guard::ContainsRecursionState;
14+
use crate::recursion_guard::RecursionError;
1315
use crate::recursion_guard::RecursionGuard;
16+
use crate::recursion_guard::RecursionState;
1417

1518
/// this is ugly, would be much better if extra could be stored in `SerializationState`
1619
/// then `SerializationState` got a `serialize_infer` method, but I couldn't get it to work
1720
pub(crate) struct SerializationState {
1821
warnings: CollectWarnings,
19-
rec_guard: SerRecursionGuard,
22+
rec_guard: SerRecursionState,
2023
config: SerializationConfig,
2124
}
2225

2326
impl SerializationState {
2427
pub fn new(timedelta_mode: &str, bytes_mode: &str, inf_nan_mode: &str) -> PyResult<Self> {
2528
let warnings = CollectWarnings::new(false);
26-
let rec_guard = SerRecursionGuard::default();
29+
let rec_guard = SerRecursionState::default();
2730
let config = SerializationConfig::from_args(timedelta_mode, bytes_mode, inf_nan_mode)?;
2831
Ok(Self {
2932
warnings,
@@ -77,7 +80,7 @@ pub(crate) struct Extra<'a> {
7780
pub exclude_none: bool,
7881
pub round_trip: bool,
7982
pub config: &'a SerializationConfig,
80-
pub rec_guard: &'a SerRecursionGuard,
83+
pub rec_guard: &'a SerRecursionState,
8184
// the next two are used for union logic
8285
pub check: SerCheck,
8386
// data representing the current model field
@@ -101,7 +104,7 @@ impl<'a> Extra<'a> {
101104
exclude_none: bool,
102105
round_trip: bool,
103106
config: &'a SerializationConfig,
104-
rec_guard: &'a SerRecursionGuard,
107+
rec_guard: &'a SerRecursionState,
105108
serialize_unknown: bool,
106109
fallback: Option<&'a PyAny>,
107110
) -> Self {
@@ -124,6 +127,23 @@ impl<'a> Extra<'a> {
124127
}
125128
}
126129

130+
pub fn recursion_guard<'x, 'y>(
131+
// TODO: this double reference is a bit of a hack, but it's necessary because the recursion
132+
// guard is not passed around with &mut reference
133+
//
134+
// See how validation has &mut ValidationState passed around; we should aim to refactor
135+
// to match that.
136+
self: &'x mut &'y Self,
137+
value: &PyAny,
138+
def_ref_id: usize,
139+
inc_depth: bool,
140+
) -> PyResult<RecursionGuard<'x, &'y Self>> {
141+
RecursionGuard::new(self, value.as_ptr() as usize, def_ref_id, inc_depth).map_err(|e| match e {
142+
RecursionError::Depth => PyValueError::new_err("Circular reference detected (depth exceeded)"),
143+
RecursionError::Cyclic => PyValueError::new_err("Circular reference detected (id repeated)"),
144+
})
145+
}
146+
127147
pub fn serialize_infer<'py>(&'py self, value: &'py PyAny) -> super::infer::SerializeInfer<'py> {
128148
super::infer::SerializeInfer::new(value, None, None, self)
129149
}
@@ -157,7 +177,7 @@ pub(crate) struct ExtraOwned {
157177
exclude_none: bool,
158178
round_trip: bool,
159179
config: SerializationConfig,
160-
rec_guard: SerRecursionGuard,
180+
rec_guard: SerRecursionState,
161181
check: SerCheck,
162182
model: Option<PyObject>,
163183
field_name: Option<String>,
@@ -340,29 +360,35 @@ impl CollectWarnings {
340360

341361
#[derive(Default, Clone)]
342362
#[cfg_attr(debug_assertions, derive(Debug))]
343-
pub struct SerRecursionGuard {
344-
guard: RefCell<RecursionGuard>,
363+
pub struct SerRecursionState {
364+
guard: RefCell<RecursionState>,
345365
}
346366

347-
impl SerRecursionGuard {
348-
pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<usize> {
349-
let id = value.as_ptr() as usize;
350-
let mut guard = self.guard.borrow_mut();
351-
352-
if guard.insert(id, def_ref_id) {
353-
if guard.incr_depth() {
354-
Err(PyValueError::new_err("Circular reference detected (depth exceeded)"))
355-
} else {
356-
Ok(id)
357-
}
358-
} else {
359-
Err(PyValueError::new_err("Circular reference detected (id repeated)"))
360-
}
361-
}
362-
363-
pub fn pop(&self, id: usize, def_ref_id: usize) {
364-
let mut guard = self.guard.borrow_mut();
365-
guard.decr_depth();
366-
guard.remove(id, def_ref_id);
367+
impl ContainsRecursionState for &'_ Extra<'_> {
368+
fn access_recursion_state<R>(&mut self, f: impl FnOnce(&mut RecursionState) -> R) -> R {
369+
f(&mut self.rec_guard.guard.borrow_mut())
367370
}
368371
}
372+
373+
// impl SerRecursionState {
374+
// pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<usize> {
375+
// let id = value.as_ptr() as usize;
376+
// let mut guard = self.guard.borrow_mut();
377+
378+
// if guard.insert(id, def_ref_id) {
379+
// if guard.incr_depth() {
380+
// Err(PyValueError::new_err("Circular reference detected (depth exceeded)"))
381+
// } else {
382+
// Ok(id)
383+
// }
384+
// } else {
385+
// Err(PyValueError::new_err("Circular reference detected (id repeated)"))
386+
// }
387+
// }
388+
389+
// pub fn pop(&self, id: usize, def_ref_id: usize) {
390+
// let mut guard = self.guard.borrow_mut();
391+
// guard.decr_depth();
392+
// guard.remove(id, def_ref_id);
393+
// }
394+
// }

src/serializers/infer.rs

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,22 @@ pub(crate) fn infer_to_python_known(
4040
value: &PyAny,
4141
include: Option<&PyAny>,
4242
exclude: Option<&PyAny>,
43-
extra: &Extra,
43+
mut extra: &Extra,
4444
) -> PyResult<PyObject> {
4545
let py = value.py();
46-
let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID) {
47-
Ok(id) => id,
46+
47+
let mode = extra.mode;
48+
let mut guard = match extra.recursion_guard(value, INFER_DEF_REF_ID, true) {
49+
Ok(v) => v,
4850
Err(e) => {
49-
return match extra.mode {
51+
return match mode {
5052
SerMode::Json => Err(e),
5153
// if recursion is detected by we're serializing to python, we just return the value
5254
_ => Ok(value.into_py(py)),
5355
};
5456
}
5557
};
58+
let extra = guard.state();
5659

5760
macro_rules! serialize_seq {
5861
($t:ty) => {
@@ -220,7 +223,6 @@ pub(crate) fn infer_to_python_known(
220223
if let Some(fallback) = extra.fallback {
221224
let next_value = fallback.call1((value,))?;
222225
let next_result = infer_to_python(next_value, include, exclude, extra);
223-
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
224226
return next_result;
225227
} else if extra.serialize_unknown {
226228
serialize_unknown(value).into_py(py)
@@ -267,15 +269,13 @@ pub(crate) fn infer_to_python_known(
267269
if let Some(fallback) = extra.fallback {
268270
let next_value = fallback.call1((value,))?;
269271
let next_result = infer_to_python(next_value, include, exclude, extra);
270-
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
271272
return next_result;
272273
}
273274
value.into_py(py)
274275
}
275276
_ => value.into_py(py),
276277
},
277278
};
278-
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
279279
Ok(value)
280280
}
281281

@@ -332,18 +332,21 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
332332
serializer: S,
333333
include: Option<&PyAny>,
334334
exclude: Option<&PyAny>,
335-
extra: &Extra,
335+
mut extra: &Extra,
336336
) -> Result<S::Ok, S::Error> {
337-
let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID).map_err(py_err_se_err) {
337+
let extra_serialize_unknown = extra.serialize_unknown;
338+
let mut guard = match extra.recursion_guard(value, INFER_DEF_REF_ID, true) {
338339
Ok(v) => v,
339340
Err(e) => {
340-
return if extra.serialize_unknown {
341+
return if extra_serialize_unknown {
341342
serializer.serialize_str("...")
342343
} else {
343-
Err(e)
344-
}
344+
Err(py_err_se_err(e))
345+
};
345346
}
346347
};
348+
let extra = guard.state();
349+
347350
macro_rules! serialize {
348351
($t:ty) => {
349352
match value.extract::<$t>() {
@@ -506,7 +509,6 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
506509
if let Some(fallback) = extra.fallback {
507510
let next_value = fallback.call1((value,)).map_err(py_err_se_err)?;
508511
let next_result = infer_serialize(next_value, serializer, include, exclude, extra);
509-
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
510512
return next_result;
511513
} else if extra.serialize_unknown {
512514
serializer.serialize_str(&serialize_unknown(value))
@@ -520,7 +522,6 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
520522
}
521523
}
522524
};
523-
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
524525
ser_result
525526
}
526527

src/serializers/mod.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::py_gc::PyGcTraverse;
1010

1111
use config::SerializationConfig;
1212
pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue};
13-
use extra::{CollectWarnings, SerRecursionGuard};
13+
use extra::{CollectWarnings, SerRecursionState};
1414
pub(crate) use extra::{Extra, SerMode, SerializationState};
1515
pub use shared::CombinedSerializer;
1616
use shared::{to_json_bytes, BuildSerializer, TypeSerializer};
@@ -52,7 +52,7 @@ impl SchemaSerializer {
5252
exclude_defaults: bool,
5353
exclude_none: bool,
5454
round_trip: bool,
55-
rec_guard: &'a SerRecursionGuard,
55+
rec_guard: &'a SerRecursionState,
5656
serialize_unknown: bool,
5757
fallback: Option<&'a PyAny>,
5858
) -> Extra<'b> {
@@ -113,7 +113,7 @@ impl SchemaSerializer {
113113
) -> PyResult<PyObject> {
114114
let mode: SerMode = mode.into();
115115
let warnings = CollectWarnings::new(warnings);
116-
let rec_guard = SerRecursionGuard::default();
116+
let rec_guard = SerRecursionState::default();
117117
let extra = self.build_extra(
118118
py,
119119
&mode,
@@ -152,7 +152,7 @@ impl SchemaSerializer {
152152
fallback: Option<&PyAny>,
153153
) -> PyResult<PyObject> {
154154
let warnings = CollectWarnings::new(warnings);
155-
let rec_guard = SerRecursionGuard::default();
155+
let rec_guard = SerRecursionState::default();
156156
let extra = self.build_extra(
157157
py,
158158
&SerMode::Json,

0 commit comments

Comments
 (0)