Skip to content

Commit e1cb0eb

Browse files
improve performance of recursion guard (#1156)
Co-authored-by: David Hewitt <[email protected]> Co-authored-by: David Hewitt <[email protected]>
1 parent d7cf72d commit e1cb0eb

File tree

4 files changed

+134
-42
lines changed

4 files changed

+134
-42
lines changed

src/recursion_guard.rs

Lines changed: 118 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use ahash::AHashSet;
2+
use std::mem::MaybeUninit;
23

34
type RecursionKey = (
45
// Identifier for the input object, e.g. the id() of a Python dict
@@ -13,56 +14,147 @@ type RecursionKey = (
1314
/// It's used in `validators/definition` to detect when a reference is reused within itself.
1415
#[derive(Debug, Clone, Default)]
1516
pub struct RecursionGuard {
16-
ids: Option<AHashSet<RecursionKey>>,
17+
ids: RecursionStack,
1718
// depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just
1819
// use one number for all validators
19-
depth: u16,
20+
depth: u8,
2021
}
2122

2223
// A hard limit to avoid stack overflows when rampant recursion occurs
23-
pub const RECURSION_GUARD_LIMIT: u16 = if cfg!(any(target_family = "wasm", all(windows, PyPy))) {
24+
pub const RECURSION_GUARD_LIMIT: u8 = if cfg!(any(target_family = "wasm", all(windows, PyPy))) {
2425
// wasm and windows PyPy have very limited stack sizes
25-
50
26+
49
2627
} else if cfg!(any(PyPy, windows)) {
2728
// PyPy and Windows in general have more restricted stack space
28-
100
29+
99
2930
} else {
3031
255
3132
};
3233

3334
impl RecursionGuard {
34-
// insert a new id into the set, return whether the set already had the id in it
35-
pub fn contains_or_insert(&mut self, obj_id: usize, node_id: usize) -> bool {
36-
match self.ids {
37-
// https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
38-
// "If the set did not have this value present, `true` is returned."
39-
Some(ref mut set) => !set.insert((obj_id, node_id)),
40-
None => {
41-
let mut set: AHashSet<RecursionKey> = AHashSet::with_capacity(10);
42-
set.insert((obj_id, node_id));
43-
self.ids = Some(set);
44-
false
45-
}
46-
}
35+
// insert a new value
36+
// * return `false` if the stack already had it in it
37+
// * return `true` if the stack didn't have it in it and it was inserted
38+
pub fn insert(&mut self, obj_id: usize, node_id: usize) -> bool {
39+
self.ids.insert((obj_id, node_id))
4740
}
4841

4942
// see #143 this is used as a backup in case the identity check recursion guard fails
5043
#[must_use]
44+
#[cfg(any(target_family = "wasm", windows, PyPy))]
5145
pub fn incr_depth(&mut self) -> bool {
52-
self.depth += 1;
53-
self.depth >= RECURSION_GUARD_LIMIT
46+
// use saturating_add as it's faster (since there's no error path)
47+
// and the RECURSION_GUARD_LIMIT check will be hit before it overflows
48+
debug_assert!(RECURSION_GUARD_LIMIT < 255);
49+
self.depth = self.depth.saturating_add(1);
50+
self.depth > RECURSION_GUARD_LIMIT
51+
}
52+
53+
#[must_use]
54+
#[cfg(not(any(target_family = "wasm", windows, PyPy)))]
55+
pub fn incr_depth(&mut self) -> bool {
56+
debug_assert_eq!(RECURSION_GUARD_LIMIT, 255);
57+
// use checked_add to check if we've hit the limit
58+
if let Some(depth) = self.depth.checked_add(1) {
59+
self.depth = depth;
60+
false
61+
} else {
62+
true
63+
}
5464
}
5565

5666
pub fn decr_depth(&mut self) {
57-
self.depth -= 1;
67+
// for the same reason as incr_depth, use saturating_sub
68+
self.depth = self.depth.saturating_sub(1);
5869
}
5970

6071
pub fn remove(&mut self, obj_id: usize, node_id: usize) {
61-
match self.ids {
62-
Some(ref mut set) => {
63-
set.remove(&(obj_id, node_id));
72+
self.ids.remove(&(obj_id, node_id));
73+
}
74+
}
75+
76+
// trial and error suggests this is a good value; going higher causes array lookups to get significantly slower
77+
const ARRAY_SIZE: usize = 16;
78+
79+
#[derive(Debug, Clone)]
80+
enum RecursionStack {
81+
Array {
82+
data: [MaybeUninit<RecursionKey>; ARRAY_SIZE],
83+
len: usize,
84+
},
85+
Set(AHashSet<RecursionKey>),
86+
}
87+
88+
impl Default for RecursionStack {
89+
fn default() -> Self {
90+
Self::Array {
91+
data: std::array::from_fn(|_| MaybeUninit::uninit()),
92+
len: 0,
93+
}
94+
}
95+
}
96+
97+
impl RecursionStack {
98+
// insert a new value
99+
// * return `false` if the stack already had it in it
100+
// * return `true` if the stack didn't have it in it and it was inserted
101+
pub fn insert(&mut self, v: RecursionKey) -> bool {
102+
match self {
103+
Self::Array { data, len } => {
104+
if *len < ARRAY_SIZE {
105+
for value in data.iter().take(*len) {
106+
// Safety: reading values within bounds
107+
if unsafe { value.assume_init() } == v {
108+
return false;
109+
}
110+
}
111+
112+
data[*len].write(v);
113+
*len += 1;
114+
true
115+
} else {
116+
let mut set = AHashSet::with_capacity(ARRAY_SIZE + 1);
117+
for existing in data.iter() {
118+
// Safety: the array is fully initialized
119+
set.insert(unsafe { existing.assume_init() });
120+
}
121+
let inserted = set.insert(v);
122+
*self = Self::Set(set);
123+
inserted
124+
}
125+
}
126+
// https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
127+
// "If the set did not have this value present, `true` is returned."
128+
Self::Set(set) => set.insert(v),
129+
}
130+
}
131+
132+
pub fn remove(&mut self, v: &RecursionKey) {
133+
match self {
134+
Self::Array { data, len } => {
135+
*len = len.checked_sub(1).expect("remove from empty recursion guard");
136+
// Safety: this is reading what was the back of the initialized array
137+
let removed = unsafe { data.get_unchecked_mut(*len) };
138+
assert!(unsafe { removed.assume_init_ref() } == v, "remove did not match insert");
139+
// this should compile away to a noop
140+
unsafe { std::ptr::drop_in_place(removed.as_mut_ptr()) }
141+
}
142+
Self::Set(set) => {
143+
set.remove(v);
144+
}
145+
}
146+
}
147+
}
148+
149+
impl Drop for RecursionStack {
150+
fn drop(&mut self) {
151+
// This should compile away to a noop as `RecursionKey` doesn't implement Drop, but it seemed
152+
// desirable to leave this in for safety in case that should change in the future
153+
if let Self::Array { data, len } = self {
154+
for value in data.iter_mut().take(*len) {
155+
// Safety: only the first `len` slots are dropped, and `insert` initializes a slot before incrementing `len`
156+
unsafe { std::ptr::drop_in_place(value.as_mut_ptr()) };
64157
}
65-
None => unreachable!(),
66-
};
158+
}
67159
}
68160
}

src/serializers/extra.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -346,17 +346,17 @@ pub struct SerRecursionGuard {
346346

347347
impl SerRecursionGuard {
348348
pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<usize> {
349-
// https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
350-
// "If the set did not have this value present, `true` is returned."
351349
let id = value.as_ptr() as usize;
352350
let mut guard = self.guard.borrow_mut();
353351

354-
if guard.contains_or_insert(id, def_ref_id) {
355-
Err(PyValueError::new_err("Circular reference detected (id repeated)"))
356-
} else if guard.incr_depth() {
357-
Err(PyValueError::new_err("Circular reference detected (depth exceeded)"))
352+
if guard.insert(id, def_ref_id) {
353+
if guard.incr_depth() {
354+
Err(PyValueError::new_err("Circular reference detected (depth exceeded)"))
355+
} else {
356+
Ok(id)
357+
}
358358
} else {
359-
Ok(id)
359+
Err(PyValueError::new_err("Circular reference detected (id repeated)"))
360360
}
361361
}
362362

src/validators/definitions.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -76,17 +76,17 @@ impl Validator for DefinitionRefValidator {
7676
self.definition.read(|validator| {
7777
let validator = validator.unwrap();
7878
if let Some(id) = input.identity() {
79-
if state.recursion_guard.contains_or_insert(id, self.definition.id()) {
80-
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
81-
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input))
82-
} else {
79+
if state.recursion_guard.insert(id, self.definition.id()) {
8380
if state.recursion_guard.incr_depth() {
8481
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
8582
}
8683
let output = validator.validate(py, input, state);
8784
state.recursion_guard.remove(id, self.definition.id());
8885
state.recursion_guard.decr_depth();
8986
output
87+
} else {
88+
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
89+
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input))
9090
}
9191
} else {
9292
validator.validate(py, input, state)
@@ -105,17 +105,17 @@ impl Validator for DefinitionRefValidator {
105105
self.definition.read(|validator| {
106106
let validator = validator.unwrap();
107107
if let Some(id) = obj.identity() {
108-
if state.recursion_guard.contains_or_insert(id, self.definition.id()) {
109-
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
110-
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj))
111-
} else {
108+
if state.recursion_guard.insert(id, self.definition.id()) {
112109
if state.recursion_guard.incr_depth() {
113110
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj));
114111
}
115112
let output = validator.validate_assignment(py, obj, field_name, field_value, state);
116113
state.recursion_guard.remove(id, self.definition.id());
117114
state.recursion_guard.decr_depth();
118115
output
116+
} else {
117+
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
118+
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj))
119119
}
120120
} else {
121121
validator.validate_assignment(py, obj, field_name, field_value, state)

tests/serializers/test_any.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ def fallback_func(obj):
371371
f = FoobarCount(0)
372372
v = 0
373373
# when recursion is detected and we're in mode python, we just return the value
374-
expected_visits = pydantic_core._pydantic_core._recursion_limit - 1
374+
expected_visits = pydantic_core._pydantic_core._recursion_limit
375375
assert any_serializer.to_python(f, fallback=fallback_func) == HasRepr(f'<FoobarCount {expected_visits} repr>')
376376

377377
with pytest.raises(ValueError, match=r'Circular reference detected \(depth exceeded\)'):

0 commit comments

Comments
 (0)