1
1
use ahash:: AHashSet ;
2
+ use std:: mem:: MaybeUninit ;
2
3
3
4
type RecursionKey = (
4
5
// Identifier for the input object, e.g. the id() of a Python dict
@@ -13,56 +14,147 @@ type RecursionKey = (
13
14
/// It's used in `validators/definition` to detect when a reference is reused within itself.
14
15
#[ derive( Debug , Clone , Default ) ]
15
16
pub struct RecursionGuard {
16
- ids : Option < AHashSet < RecursionKey > > ,
17
+ ids : RecursionStack ,
17
18
// depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just
18
19
// use one number for all validators
19
- depth : u16 ,
20
+ depth : u8 ,
20
21
}
21
22
22
23
/// Hard ceiling on the recursion depth, used to avoid stack overflows when rampant recursion occurs.
///
/// The value is picked at compile time from the target configuration.
pub const RECURSION_GUARD_LIMIT: u8 = {
    // wasm and windows PyPy have very limited stack sizes
    let very_limited_stack = cfg!(any(target_family = "wasm", all(windows, PyPy)));
    // PyPy and Windows in general have more restricted stack space
    let restricted_stack = cfg!(any(PyPy, windows));

    if very_limited_stack {
        49
    } else if restricted_stack {
        99
    } else {
        255
    }
};
32
33
33
34
impl RecursionGuard {
34
- // insert a new id into the set, return whether the set already had the id in it
35
- pub fn contains_or_insert ( & mut self , obj_id : usize , node_id : usize ) -> bool {
36
- match self . ids {
37
- // https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
38
- // "If the set did not have this value present, `true` is returned."
39
- Some ( ref mut set) => !set. insert ( ( obj_id, node_id) ) ,
40
- None => {
41
- let mut set: AHashSet < RecursionKey > = AHashSet :: with_capacity ( 10 ) ;
42
- set. insert ( ( obj_id, node_id) ) ;
43
- self . ids = Some ( set) ;
44
- false
45
- }
46
- }
35
+ // insert a new value
36
+ // * return `false` if the stack already had it in it
37
+ // * return `true` if the stack didn't have it in it and it was inserted
38
+ pub fn insert ( & mut self , obj_id : usize , node_id : usize ) -> bool {
39
+ self . ids . insert ( ( obj_id, node_id) )
47
40
}
48
41
49
42
// see #143 this is used as a backup in case the identity check recursion guard fails
50
43
#[ must_use]
44
+ #[ cfg( any( target_family = "wasm" , windows, PyPy ) ) ]
51
45
pub fn incr_depth ( & mut self ) -> bool {
52
- self . depth += 1 ;
53
- self . depth >= RECURSION_GUARD_LIMIT
46
+ // use saturating_add as it's faster (since there's no error path)
47
+ // and the RECURSION_GUARD_LIMIT check will be hit before it overflows
48
+ debug_assert ! ( RECURSION_GUARD_LIMIT < 255 ) ;
49
+ self . depth = self . depth . saturating_add ( 1 ) ;
50
+ self . depth > RECURSION_GUARD_LIMIT
51
+ }
52
+
53
+ #[ must_use]
54
+ #[ cfg( not( any( target_family = "wasm" , windows, PyPy ) ) ) ]
55
+ pub fn incr_depth ( & mut self ) -> bool {
56
+ debug_assert_eq ! ( RECURSION_GUARD_LIMIT , 255 ) ;
57
+ // use checked_add to check if we've hit the limit
58
+ if let Some ( depth) = self . depth . checked_add ( 1 ) {
59
+ self . depth = depth;
60
+ false
61
+ } else {
62
+ true
63
+ }
54
64
}
55
65
56
66
pub fn decr_depth ( & mut self ) {
57
- self . depth -= 1 ;
67
+ // for the same reason as incr_depth, use saturating_sub
68
+ self . depth = self . depth . saturating_sub ( 1 ) ;
58
69
}
59
70
60
71
pub fn remove ( & mut self , obj_id : usize , node_id : usize ) {
61
- match self . ids {
62
- Some ( ref mut set) => {
63
- set. remove ( & ( obj_id, node_id) ) ;
72
+ self . ids . remove ( & ( obj_id, node_id) ) ;
73
+ }
74
+ }
75
+
76
// Number of keys kept in the inline array before falling back to a hash set.
// Trial and error suggests this is a good value; going higher makes array lookups significantly slower.
const ARRAY_SIZE: usize = 16;
78
+
79
+ #[ derive( Debug , Clone ) ]
80
+ enum RecursionStack {
81
+ Array {
82
+ data : [ MaybeUninit < RecursionKey > ; ARRAY_SIZE ] ,
83
+ len : usize ,
84
+ } ,
85
+ Set ( AHashSet < RecursionKey > ) ,
86
+ }
87
+
88
+ impl Default for RecursionStack {
89
+ fn default ( ) -> Self {
90
+ Self :: Array {
91
+ data : std:: array:: from_fn ( |_| MaybeUninit :: uninit ( ) ) ,
92
+ len : 0 ,
93
+ }
94
+ }
95
+ }
96
+
97
+ impl RecursionStack {
98
+ // insert a new value
99
+ // * return `false` if the stack already had it in it
100
+ // * return `true` if the stack didn't have it in it and it was inserted
101
+ pub fn insert ( & mut self , v : RecursionKey ) -> bool {
102
+ match self {
103
+ Self :: Array { data, len } => {
104
+ if * len < ARRAY_SIZE {
105
+ for value in data. iter ( ) . take ( * len) {
106
+ // Safety: reading values within bounds
107
+ if unsafe { value. assume_init ( ) } == v {
108
+ return false ;
109
+ }
110
+ }
111
+
112
+ data[ * len] . write ( v) ;
113
+ * len += 1 ;
114
+ true
115
+ } else {
116
+ let mut set = AHashSet :: with_capacity ( ARRAY_SIZE + 1 ) ;
117
+ for existing in data. iter ( ) {
118
+ // Safety: the array is fully initialized
119
+ set. insert ( unsafe { existing. assume_init ( ) } ) ;
120
+ }
121
+ let inserted = set. insert ( v) ;
122
+ * self = Self :: Set ( set) ;
123
+ inserted
124
+ }
125
+ }
126
+ // https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
127
+ // "If the set did not have this value present, `true` is returned."
128
+ Self :: Set ( set) => set. insert ( v) ,
129
+ }
130
+ }
131
+
132
+ pub fn remove ( & mut self , v : & RecursionKey ) {
133
+ match self {
134
+ Self :: Array { data, len } => {
135
+ * len = len. checked_sub ( 1 ) . expect ( "remove from empty recursion guard" ) ;
136
+ // Safety: this is reading what was the back of the initialized array
137
+ let removed = unsafe { data. get_unchecked_mut ( * len) } ;
138
+ assert ! ( unsafe { removed. assume_init_ref( ) } == v, "remove did not match insert" ) ;
139
+ // this should compile away to a noop
140
+ unsafe { std:: ptr:: drop_in_place ( removed. as_mut_ptr ( ) ) }
141
+ }
142
+ Self :: Set ( set) => {
143
+ set. remove ( v) ;
144
+ }
145
+ }
146
+ }
147
+ }
148
+
149
+ impl Drop for RecursionStack {
150
+ fn drop ( & mut self ) {
151
+ // This should compile away to a noop as Recursion>Key doesn't implement Drop, but it seemed
152
+ // desirable to leave this in for safety in case that should change in the future
153
+ if let Self :: Array { data, len } = self {
154
+ for value in data. iter_mut ( ) . take ( * len) {
155
+ // Safety: reading values within bounds
156
+ unsafe { std:: ptr:: drop_in_place ( value. as_mut_ptr ( ) ) } ;
64
157
}
65
- None => unreachable ! ( ) ,
66
- } ;
158
+ }
67
159
}
68
160
}
0 commit comments