Skip to content

Commit 7a5f8e6

Browse files
authored
Ensure recursion guard is always used as a stack (#1166)
1 parent 4da7192 commit 7a5f8e6

File tree

9 files changed

+137
-97
lines changed

9 files changed

+137
-97
lines changed

src/recursion_guard.rs

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,59 @@ type RecursionKey = (
1212

1313
/// This is used to avoid cyclic references in input data causing recursive validation and a nasty segmentation fault.
1414
/// It's used in `validators/definition` to detect when a reference is reused within itself.
15+
pub(crate) struct RecursionGuard<'a, S: ContainsRecursionState> {
16+
state: &'a mut S,
17+
obj_id: usize,
18+
node_id: usize,
19+
}
20+
21+
pub(crate) enum RecursionError {
22+
/// Cyclic reference detected
23+
Cyclic,
24+
/// Recursion limit exceeded
25+
Depth,
26+
}
27+
28+
impl<S: ContainsRecursionState> RecursionGuard<'_, S> {
29+
/// Creates a recursion guard for the given object and node id.
30+
///
31+
/// When dropped, this will release the recursion for the given object and node id.
32+
pub fn new(state: &'_ mut S, obj_id: usize, node_id: usize) -> Result<RecursionGuard<'_, S>, RecursionError> {
33+
state.access_recursion_state(|state| {
34+
if !state.insert(obj_id, node_id) {
35+
return Err(RecursionError::Cyclic);
36+
}
37+
if state.incr_depth() {
38+
return Err(RecursionError::Depth);
39+
}
40+
Ok(())
41+
})?;
42+
Ok(RecursionGuard { state, obj_id, node_id })
43+
}
44+
45+
/// Retrieves the underlying state for further use.
46+
pub fn state(&mut self) -> &mut S {
47+
self.state
48+
}
49+
}
50+
51+
impl<S: ContainsRecursionState> Drop for RecursionGuard<'_, S> {
52+
fn drop(&mut self) {
53+
self.state.access_recursion_state(|state| {
54+
state.decr_depth();
55+
state.remove(self.obj_id, self.node_id);
56+
});
57+
}
58+
}
59+
60+
/// This trait is used to retrieve the recursion state from some other type
61+
pub(crate) trait ContainsRecursionState {
62+
fn access_recursion_state<R>(&mut self, f: impl FnOnce(&mut RecursionState) -> R) -> R;
63+
}
64+
65+
/// State for the RecursionGuard. Can also be used directly to increase / decrease depth.
1566
#[derive(Debug, Clone, Default)]
16-
pub struct RecursionGuard {
67+
pub struct RecursionState {
1768
ids: RecursionStack,
1869
// depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just
1970
// use one number for all validators
@@ -31,11 +82,11 @@ pub const RECURSION_GUARD_LIMIT: u8 = if cfg!(any(target_family = "wasm", all(wi
3182
255
3283
};
3384

34-
impl RecursionGuard {
85+
impl RecursionState {
3586
// insert a new value
3687
// * return `false` if the stack already had it in it
3788
// * return `true` if the stack didn't have it in it and it was inserted
38-
pub fn insert(&mut self, obj_id: usize, node_id: usize) -> bool {
89+
fn insert(&mut self, obj_id: usize, node_id: usize) -> bool {
3990
self.ids.insert((obj_id, node_id))
4091
}
4192

@@ -68,7 +119,7 @@ impl RecursionGuard {
68119
self.depth = self.depth.saturating_sub(1);
69120
}
70121

71-
pub fn remove(&mut self, obj_id: usize, node_id: usize) {
122+
fn remove(&mut self, obj_id: usize, node_id: usize) {
72123
self.ids.remove(&(obj_id, node_id));
73124
}
74125
}
@@ -98,7 +149,7 @@ impl RecursionStack {
98149
// insert a new value
99150
// * return `false` if the stack already had it in it
100151
// * return `true` if the stack didn't have it in it and it was inserted
101-
pub fn insert(&mut self, v: RecursionKey) -> bool {
152+
fn insert(&mut self, v: RecursionKey) -> bool {
102153
match self {
103154
Self::Array { data, len } => {
104155
if *len < ARRAY_SIZE {
@@ -129,7 +180,7 @@ impl RecursionStack {
129180
}
130181
}
131182

132-
pub fn remove(&mut self, v: &RecursionKey) {
183+
fn remove(&mut self, v: &RecursionKey) {
133184
match self {
134185
Self::Array { data, len } => {
135186
*len = len.checked_sub(1).expect("remove from empty recursion guard");

src/serializers/extra.rs

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,23 @@ use serde::ser::Error;
1010
use super::config::SerializationConfig;
1111
use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER};
1212
use super::ob_type::ObTypeLookup;
13+
use crate::recursion_guard::ContainsRecursionState;
14+
use crate::recursion_guard::RecursionError;
1315
use crate::recursion_guard::RecursionGuard;
16+
use crate::recursion_guard::RecursionState;
1417

1518
/// this is ugly, would be much better if extra could be stored in `SerializationState`
1619
/// then `SerializationState` got a `serialize_infer` method, but I couldn't get it to work
1720
pub(crate) struct SerializationState {
1821
warnings: CollectWarnings,
19-
rec_guard: SerRecursionGuard,
22+
rec_guard: SerRecursionState,
2023
config: SerializationConfig,
2124
}
2225

2326
impl SerializationState {
2427
pub fn new(timedelta_mode: &str, bytes_mode: &str, inf_nan_mode: &str) -> PyResult<Self> {
2528
let warnings = CollectWarnings::new(false);
26-
let rec_guard = SerRecursionGuard::default();
29+
let rec_guard = SerRecursionState::default();
2730
let config = SerializationConfig::from_args(timedelta_mode, bytes_mode, inf_nan_mode)?;
2831
Ok(Self {
2932
warnings,
@@ -77,7 +80,7 @@ pub(crate) struct Extra<'a> {
7780
pub exclude_none: bool,
7881
pub round_trip: bool,
7982
pub config: &'a SerializationConfig,
80-
pub rec_guard: &'a SerRecursionGuard,
83+
pub rec_guard: &'a SerRecursionState,
8184
// the next two are used for union logic
8285
pub check: SerCheck,
8386
// data representing the current model field
@@ -101,7 +104,7 @@ impl<'a> Extra<'a> {
101104
exclude_none: bool,
102105
round_trip: bool,
103106
config: &'a SerializationConfig,
104-
rec_guard: &'a SerRecursionGuard,
107+
rec_guard: &'a SerRecursionState,
105108
serialize_unknown: bool,
106109
fallback: Option<&'a PyAny>,
107110
) -> Self {
@@ -124,6 +127,22 @@ impl<'a> Extra<'a> {
124127
}
125128
}
126129

130+
pub fn recursion_guard<'x, 'y>(
131+
// TODO: this double reference is a bit of a hack, but it's necessary because the recursion
132+
// guard is not passed around with &mut reference
133+
//
134+
// See how validation has &mut ValidationState passed around; we should aim to refactor
135+
// to match that.
136+
self: &'x mut &'y Self,
137+
value: &PyAny,
138+
def_ref_id: usize,
139+
) -> PyResult<RecursionGuard<'x, &'y Self>> {
140+
RecursionGuard::new(self, value.as_ptr() as usize, def_ref_id).map_err(|e| match e {
141+
RecursionError::Depth => PyValueError::new_err("Circular reference detected (depth exceeded)"),
142+
RecursionError::Cyclic => PyValueError::new_err("Circular reference detected (id repeated)"),
143+
})
144+
}
145+
127146
pub fn serialize_infer<'py>(&'py self, value: &'py PyAny) -> super::infer::SerializeInfer<'py> {
128147
super::infer::SerializeInfer::new(value, None, None, self)
129148
}
@@ -157,7 +176,7 @@ pub(crate) struct ExtraOwned {
157176
exclude_none: bool,
158177
round_trip: bool,
159178
config: SerializationConfig,
160-
rec_guard: SerRecursionGuard,
179+
rec_guard: SerRecursionState,
161180
check: SerCheck,
162181
model: Option<PyObject>,
163182
field_name: Option<String>,
@@ -340,29 +359,12 @@ impl CollectWarnings {
340359

341360
#[derive(Default, Clone)]
342361
#[cfg_attr(debug_assertions, derive(Debug))]
343-
pub struct SerRecursionGuard {
344-
guard: RefCell<RecursionGuard>,
362+
pub struct SerRecursionState {
363+
guard: RefCell<RecursionState>,
345364
}
346365

347-
impl SerRecursionGuard {
348-
pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<usize> {
349-
let id = value.as_ptr() as usize;
350-
let mut guard = self.guard.borrow_mut();
351-
352-
if guard.insert(id, def_ref_id) {
353-
if guard.incr_depth() {
354-
Err(PyValueError::new_err("Circular reference detected (depth exceeded)"))
355-
} else {
356-
Ok(id)
357-
}
358-
} else {
359-
Err(PyValueError::new_err("Circular reference detected (id repeated)"))
360-
}
361-
}
362-
363-
pub fn pop(&self, id: usize, def_ref_id: usize) {
364-
let mut guard = self.guard.borrow_mut();
365-
guard.decr_depth();
366-
guard.remove(id, def_ref_id);
366+
impl ContainsRecursionState for &'_ Extra<'_> {
367+
fn access_recursion_state<R>(&mut self, f: impl FnOnce(&mut RecursionState) -> R) -> R {
368+
f(&mut self.rec_guard.guard.borrow_mut())
367369
}
368370
}

src/serializers/infer.rs

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,22 @@ pub(crate) fn infer_to_python_known(
4040
value: &PyAny,
4141
include: Option<&PyAny>,
4242
exclude: Option<&PyAny>,
43-
extra: &Extra,
43+
mut extra: &Extra,
4444
) -> PyResult<PyObject> {
4545
let py = value.py();
46-
let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID) {
47-
Ok(id) => id,
46+
47+
let mode = extra.mode;
48+
let mut guard = match extra.recursion_guard(value, INFER_DEF_REF_ID) {
49+
Ok(v) => v,
4850
Err(e) => {
49-
return match extra.mode {
51+
return match mode {
5052
SerMode::Json => Err(e),
5153
// if recursion is detected but we're serializing to python, we just return the value
5254
_ => Ok(value.into_py(py)),
5355
};
5456
}
5557
};
58+
let extra = guard.state();
5659

5760
macro_rules! serialize_seq {
5861
($t:ty) => {
@@ -220,7 +223,6 @@ pub(crate) fn infer_to_python_known(
220223
if let Some(fallback) = extra.fallback {
221224
let next_value = fallback.call1((value,))?;
222225
let next_result = infer_to_python(next_value, include, exclude, extra);
223-
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
224226
return next_result;
225227
} else if extra.serialize_unknown {
226228
serialize_unknown(value).into_py(py)
@@ -267,15 +269,13 @@ pub(crate) fn infer_to_python_known(
267269
if let Some(fallback) = extra.fallback {
268270
let next_value = fallback.call1((value,))?;
269271
let next_result = infer_to_python(next_value, include, exclude, extra);
270-
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
271272
return next_result;
272273
}
273274
value.into_py(py)
274275
}
275276
_ => value.into_py(py),
276277
},
277278
};
278-
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
279279
Ok(value)
280280
}
281281

@@ -332,18 +332,21 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
332332
serializer: S,
333333
include: Option<&PyAny>,
334334
exclude: Option<&PyAny>,
335-
extra: &Extra,
335+
mut extra: &Extra,
336336
) -> Result<S::Ok, S::Error> {
337-
let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID).map_err(py_err_se_err) {
337+
let extra_serialize_unknown = extra.serialize_unknown;
338+
let mut guard = match extra.recursion_guard(value, INFER_DEF_REF_ID) {
338339
Ok(v) => v,
339340
Err(e) => {
340-
return if extra.serialize_unknown {
341+
return if extra_serialize_unknown {
341342
serializer.serialize_str("...")
342343
} else {
343-
Err(e)
344-
}
344+
Err(py_err_se_err(e))
345+
};
345346
}
346347
};
348+
let extra = guard.state();
349+
347350
macro_rules! serialize {
348351
($t:ty) => {
349352
match value.extract::<$t>() {
@@ -506,7 +509,6 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
506509
if let Some(fallback) = extra.fallback {
507510
let next_value = fallback.call1((value,)).map_err(py_err_se_err)?;
508511
let next_result = infer_serialize(next_value, serializer, include, exclude, extra);
509-
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
510512
return next_result;
511513
} else if extra.serialize_unknown {
512514
serializer.serialize_str(&serialize_unknown(value))
@@ -520,7 +522,6 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
520522
}
521523
}
522524
};
523-
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
524525
ser_result
525526
}
526527

src/serializers/mod.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::py_gc::PyGcTraverse;
1010

1111
use config::SerializationConfig;
1212
pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue};
13-
use extra::{CollectWarnings, SerRecursionGuard};
13+
use extra::{CollectWarnings, SerRecursionState};
1414
pub(crate) use extra::{Extra, SerMode, SerializationState};
1515
pub use shared::CombinedSerializer;
1616
use shared::{to_json_bytes, BuildSerializer, TypeSerializer};
@@ -52,7 +52,7 @@ impl SchemaSerializer {
5252
exclude_defaults: bool,
5353
exclude_none: bool,
5454
round_trip: bool,
55-
rec_guard: &'a SerRecursionGuard,
55+
rec_guard: &'a SerRecursionState,
5656
serialize_unknown: bool,
5757
fallback: Option<&'a PyAny>,
5858
) -> Extra<'b> {
@@ -113,7 +113,7 @@ impl SchemaSerializer {
113113
) -> PyResult<PyObject> {
114114
let mode: SerMode = mode.into();
115115
let warnings = CollectWarnings::new(warnings);
116-
let rec_guard = SerRecursionGuard::default();
116+
let rec_guard = SerRecursionState::default();
117117
let extra = self.build_extra(
118118
py,
119119
&mode,
@@ -152,7 +152,7 @@ impl SchemaSerializer {
152152
fallback: Option<&PyAny>,
153153
) -> PyResult<PyObject> {
154154
let warnings = CollectWarnings::new(warnings);
155-
let rec_guard = SerRecursionGuard::default();
155+
let rec_guard = SerRecursionState::default();
156156
let extra = self.build_extra(
157157
py,
158158
&SerMode::Json,

src/serializers/type_serializers/definitions.rs

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,12 @@ impl TypeSerializer for DefinitionRefSerializer {
6666
value: &PyAny,
6767
include: Option<&PyAny>,
6868
exclude: Option<&PyAny>,
69-
extra: &Extra,
69+
mut extra: &Extra,
7070
) -> PyResult<PyObject> {
7171
self.definition.read(|comb_serializer| {
7272
let comb_serializer = comb_serializer.unwrap();
73-
let value_id = extra.rec_guard.add(value, self.definition.id())?;
74-
let r = comb_serializer.to_python(value, include, exclude, extra);
75-
extra.rec_guard.pop(value_id, self.definition.id());
76-
r
73+
let mut guard = extra.recursion_guard(value, self.definition.id())?;
74+
comb_serializer.to_python(value, include, exclude, guard.state())
7775
})
7876
}
7977

@@ -87,17 +85,14 @@ impl TypeSerializer for DefinitionRefSerializer {
8785
serializer: S,
8886
include: Option<&PyAny>,
8987
exclude: Option<&PyAny>,
90-
extra: &Extra,
88+
mut extra: &Extra,
9189
) -> Result<S::Ok, S::Error> {
9290
self.definition.read(|comb_serializer| {
9391
let comb_serializer = comb_serializer.unwrap();
94-
let value_id = extra
95-
.rec_guard
96-
.add(value, self.definition.id())
92+
let mut guard = extra
93+
.recursion_guard(value, self.definition.id())
9794
.map_err(py_err_se_err)?;
98-
let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra);
99-
extra.rec_guard.pop(value_id, self.definition.id());
100-
r
95+
comb_serializer.serde_serialize(value, serializer, include, exclude, guard.state())
10196
})
10297
}
10398

0 commit comments

Comments
 (0)