@@ -74,12 +74,23 @@ impl MutexMetadata {
74
74
}
75
75
}
76
76
77
- fn pre_lock ( this : & Arc < MutexMetadata > ) {
77
+ // Returns whether we were a recursive lock (only relevant for read)
78
+ fn _pre_lock ( this : & Arc < MutexMetadata > , read : bool ) -> bool {
79
+ let mut inserted = false ;
78
80
MUTEXES_HELD . with ( |held| {
79
81
// For each mutex which is currently locked, check that no mutex's locked-before
80
82
// set includes the mutex we're about to lock, which would imply a lockorder
81
83
// inversion.
82
84
for locked in held. borrow ( ) . iter ( ) {
85
+ if read && * locked == * this {
86
+ // Recursive read locks are explicitly allowed
87
+ return ;
88
+ }
89
+ }
90
+ for locked in held. borrow ( ) . iter ( ) {
91
+ if !read && * locked == * this {
92
+ panic ! ( "Tried to lock a mutex while it was held!" ) ;
93
+ }
83
94
for locked_dep in locked. locked_before . lock ( ) . unwrap ( ) . iter ( ) {
84
95
if * locked_dep == * this {
85
96
#[ cfg( feature = "backtrace" ) ]
@@ -92,9 +103,14 @@ impl MutexMetadata {
92
103
this. locked_before . lock ( ) . unwrap ( ) . insert ( Arc :: clone ( locked) ) ;
93
104
}
94
105
held. borrow_mut ( ) . insert ( Arc :: clone ( this) ) ;
106
+ inserted = true ;
95
107
} ) ;
108
+ inserted
96
109
}
97
110
111
+ fn pre_lock ( this : & Arc < MutexMetadata > ) { Self :: _pre_lock ( this, false ) ; }
112
+ fn pre_read_lock ( this : & Arc < MutexMetadata > ) -> bool { Self :: _pre_lock ( this, true ) }
113
+
98
114
fn try_locked ( this : & Arc < MutexMetadata > ) {
99
115
MUTEXES_HELD . with ( |held| {
100
116
// Since a try-lock will simply fail if the lock is held already, we do not
@@ -171,54 +187,178 @@ impl<T> Mutex<T> {
171
187
}
172
188
}
173
189
174
- pub struct RwLock < T : ?Sized > {
175
- inner : StdRwLock < T >
190
+ pub struct RwLock < T : Sized > {
191
+ inner : StdRwLock < T > ,
192
+ deps : Arc < MutexMetadata > ,
176
193
}
177
194
178
- pub struct RwLockReadGuard < ' a , T : ?Sized + ' a > {
195
+ pub struct RwLockReadGuard < ' a , T : Sized + ' a > {
196
+ mutex : & ' a RwLock < T > ,
197
+ first_lock : bool ,
179
198
lock : StdRwLockReadGuard < ' a , T > ,
180
199
}
181
200
182
- pub struct RwLockWriteGuard < ' a , T : ?Sized + ' a > {
201
+ pub struct RwLockWriteGuard < ' a , T : Sized + ' a > {
202
+ mutex : & ' a RwLock < T > ,
183
203
lock : StdRwLockWriteGuard < ' a , T > ,
184
204
}
185
205
186
- impl < T : ? Sized > Deref for RwLockReadGuard < ' _ , T > {
206
+ impl < T : Sized > Deref for RwLockReadGuard < ' _ , T > {
187
207
type Target = T ;
188
208
189
209
fn deref ( & self ) -> & T {
190
210
& self . lock . deref ( )
191
211
}
192
212
}
193
213
194
- impl < T : ?Sized > Deref for RwLockWriteGuard < ' _ , T > {
214
+ impl < T : Sized > Drop for RwLockReadGuard < ' _ , T > {
215
+ fn drop ( & mut self ) {
216
+ if !self . first_lock {
217
+ // Note that its not strictly true that the first taken read lock will get unlocked
218
+ // last, but in practice our locks are always taken as RAII, so it should basically
219
+ // always be true.
220
+ return ;
221
+ }
222
+ MUTEXES_HELD . with ( |held| {
223
+ held. borrow_mut ( ) . remove ( & self . mutex . deps ) ;
224
+ } ) ;
225
+ }
226
+ }
227
+
228
+ impl < T : Sized > Deref for RwLockWriteGuard < ' _ , T > {
195
229
type Target = T ;
196
230
197
231
fn deref ( & self ) -> & T {
198
232
& self . lock . deref ( )
199
233
}
200
234
}
201
235
202
- impl < T : ?Sized > DerefMut for RwLockWriteGuard < ' _ , T > {
236
+ impl < T : Sized > Drop for RwLockWriteGuard < ' _ , T > {
237
+ fn drop ( & mut self ) {
238
+ MUTEXES_HELD . with ( |held| {
239
+ held. borrow_mut ( ) . remove ( & self . mutex . deps ) ;
240
+ } ) ;
241
+ }
242
+ }
243
+
244
+ impl < T : Sized > DerefMut for RwLockWriteGuard < ' _ , T > {
203
245
fn deref_mut ( & mut self ) -> & mut T {
204
246
self . lock . deref_mut ( )
205
247
}
206
248
}
207
249
208
250
impl < T > RwLock < T > {
209
251
pub fn new ( inner : T ) -> RwLock < T > {
210
- RwLock { inner : StdRwLock :: new ( inner) }
252
+ RwLock { inner : StdRwLock :: new ( inner) , deps : Arc :: new ( MutexMetadata :: new ( ) ) }
211
253
}
212
254
213
255
pub fn read < ' a > ( & ' a self ) -> LockResult < RwLockReadGuard < ' a , T > > {
214
- self . inner . read ( ) . map ( |lock| RwLockReadGuard { lock } ) . map_err ( |_| ( ) )
256
+ let first_lock = MutexMetadata :: pre_read_lock ( & self . deps ) ;
257
+ self . inner . read ( ) . map ( |lock| RwLockReadGuard { mutex : self , lock, first_lock } ) . map_err ( |_| ( ) )
215
258
}
216
259
217
260
pub fn write < ' a > ( & ' a self ) -> LockResult < RwLockWriteGuard < ' a , T > > {
218
- self . inner . write ( ) . map ( |lock| RwLockWriteGuard { lock } ) . map_err ( |_| ( ) )
261
+ MutexMetadata :: pre_lock ( & self . deps ) ;
262
+ self . inner . write ( ) . map ( |lock| RwLockWriteGuard { mutex : self , lock } ) . map_err ( |_| ( ) )
219
263
}
220
264
221
265
pub fn try_write < ' a > ( & ' a self ) -> LockResult < RwLockWriteGuard < ' a , T > > {
222
- self . inner . try_write ( ) . map ( |lock| RwLockWriteGuard { lock } ) . map_err ( |_| ( ) )
266
+ let res = self . inner . try_write ( ) . map ( |lock| RwLockWriteGuard { mutex : self , lock } ) . map_err ( |_| ( ) ) ;
267
+ if res. is_ok ( ) {
268
+ MutexMetadata :: try_locked ( & self . deps ) ;
269
+ }
270
+ res
271
+ }
272
+ }
273
+
274
+ #[ test]
275
+ #[ should_panic]
276
+ fn recursive_lock_fail ( ) {
277
+ let mutex = Mutex :: new ( ( ) ) ;
278
+ let _a = mutex. lock ( ) . unwrap ( ) ;
279
+ let _b = mutex. lock ( ) . unwrap ( ) ;
280
+ }
281
+
282
+ #[ test]
283
+ fn recursive_read ( ) {
284
+ let lock = RwLock :: new ( ( ) ) ;
285
+ let _a = lock. read ( ) . unwrap ( ) ;
286
+ let _b = lock. read ( ) . unwrap ( ) ;
287
+ }
288
+
289
+ #[ test]
290
+ #[ should_panic]
291
+ fn lockorder_fail ( ) {
292
+ let a = Mutex :: new ( ( ) ) ;
293
+ let b = Mutex :: new ( ( ) ) ;
294
+ {
295
+ let _a = a. lock ( ) . unwrap ( ) ;
296
+ let _b = b. lock ( ) . unwrap ( ) ;
297
+ }
298
+ {
299
+ let _b = b. lock ( ) . unwrap ( ) ;
300
+ let _a = a. lock ( ) . unwrap ( ) ;
301
+ }
302
+ }
303
+
304
+ #[ test]
305
+ #[ should_panic]
306
+ fn write_lockorder_fail ( ) {
307
+ let a = RwLock :: new ( ( ) ) ;
308
+ let b = RwLock :: new ( ( ) ) ;
309
+ {
310
+ let _a = a. write ( ) . unwrap ( ) ;
311
+ let _b = b. write ( ) . unwrap ( ) ;
312
+ }
313
+ {
314
+ let _b = b. write ( ) . unwrap ( ) ;
315
+ let _a = a. write ( ) . unwrap ( ) ;
316
+ }
317
+ }
318
+
319
+ #[ test]
320
+ #[ should_panic]
321
+ fn read_lockorder_fail ( ) {
322
+ let a = RwLock :: new ( ( ) ) ;
323
+ let b = RwLock :: new ( ( ) ) ;
324
+ {
325
+ let _a = a. read ( ) . unwrap ( ) ;
326
+ let _b = b. read ( ) . unwrap ( ) ;
327
+ }
328
+ {
329
+ let _b = b. read ( ) . unwrap ( ) ;
330
+ let _a = a. read ( ) . unwrap ( ) ;
331
+ }
332
+ }
333
+
334
+ #[ test]
335
+ fn read_recurisve_no_lockorder ( ) {
336
+ // Like the above, but note that no lockorder is implied when we recursively read-lock a
337
+ // RwLock, causing this to pass just fine.
338
+ let a = RwLock :: new ( ( ) ) ;
339
+ let b = RwLock :: new ( ( ) ) ;
340
+ let _outer = a. read ( ) . unwrap ( ) ;
341
+ {
342
+ let _a = a. read ( ) . unwrap ( ) ;
343
+ let _b = b. read ( ) . unwrap ( ) ;
344
+ }
345
+ {
346
+ let _b = b. read ( ) . unwrap ( ) ;
347
+ let _a = a. read ( ) . unwrap ( ) ;
348
+ }
349
+ }
350
+
351
+ #[ test]
352
+ #[ should_panic]
353
+ fn read_write_lockorder_fail ( ) {
354
+ let a = RwLock :: new ( ( ) ) ;
355
+ let b = RwLock :: new ( ( ) ) ;
356
+ {
357
+ let _a = a. write ( ) . unwrap ( ) ;
358
+ let _b = b. read ( ) . unwrap ( ) ;
359
+ }
360
+ {
361
+ let _b = b. read ( ) . unwrap ( ) ;
362
+ let _a = a. write ( ) . unwrap ( ) ;
223
363
}
224
364
}
0 commit comments