Skip to content

Commit cd0f9d1

Browse files
gh-89967: make WeakKeyDictionary and WeakValueDictionary thread safe (#125325)
Make `WeakKeyDictionary` and `WeakValueDictionary` thread-safe by copying the underlying dict before iterating over it.
1 parent 0848932 commit cd0f9d1

File tree

3 files changed

+50
-174
lines changed

3 files changed

+50
-174
lines changed

Lib/_weakrefset.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,31 +8,6 @@
88
__all__ = ['WeakSet']
99

1010

11-
class _IterationGuard:
12-
# This context manager registers itself in the current iterators of the
13-
# weak container, such as to delay all removals until the context manager
14-
# exits.
15-
# This technique should be relatively thread-safe (since sets are).
16-
17-
def __init__(self, weakcontainer):
18-
# Don't create cycles
19-
self.weakcontainer = ref(weakcontainer)
20-
21-
def __enter__(self):
22-
w = self.weakcontainer()
23-
if w is not None:
24-
w._iterating.add(self)
25-
return self
26-
27-
def __exit__(self, e, t, b):
28-
w = self.weakcontainer()
29-
if w is not None:
30-
s = w._iterating
31-
s.remove(self)
32-
if not s:
33-
w._commit_removals()
34-
35-
3611
class WeakSet:
3712
def __init__(self, data=None):
3813
self.data = set()

Lib/weakref.py

Lines changed: 49 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
ReferenceType,
2020
_remove_dead_weakref)
2121

22-
from _weakrefset import WeakSet, _IterationGuard
22+
from _weakrefset import WeakSet
2323

2424
import _collections_abc # Import after _weakref to avoid circular import.
2525
import sys
@@ -105,53 +105,27 @@ def __init__(self, other=(), /, **kw):
105105
def remove(wr, selfref=ref(self), _atomic_removal=_remove_dead_weakref):
106106
self = selfref()
107107
if self is not None:
108-
if self._iterating:
109-
self._pending_removals.append(wr.key)
110-
else:
111-
# Atomic removal is necessary since this function
112-
# can be called asynchronously by the GC
113-
_atomic_removal(self.data, wr.key)
108+
# Atomic removal is necessary since this function
109+
# can be called asynchronously by the GC
110+
_atomic_removal(self.data, wr.key)
114111
self._remove = remove
115-
# A list of keys to be removed
116-
self._pending_removals = []
117-
self._iterating = set()
118112
self.data = {}
119113
self.update(other, **kw)
120114

121-
def _commit_removals(self, _atomic_removal=_remove_dead_weakref):
122-
pop = self._pending_removals.pop
123-
d = self.data
124-
# We shouldn't encounter any KeyError, because this method should
125-
# always be called *before* mutating the dict.
126-
while True:
127-
try:
128-
key = pop()
129-
except IndexError:
130-
return
131-
_atomic_removal(d, key)
132-
133115
def __getitem__(self, key):
134-
if self._pending_removals:
135-
self._commit_removals()
136116
o = self.data[key]()
137117
if o is None:
138118
raise KeyError(key)
139119
else:
140120
return o
141121

142122
def __delitem__(self, key):
143-
if self._pending_removals:
144-
self._commit_removals()
145123
del self.data[key]
146124

147125
def __len__(self):
148-
if self._pending_removals:
149-
self._commit_removals()
150126
return len(self.data)
151127

152128
def __contains__(self, key):
153-
if self._pending_removals:
154-
self._commit_removals()
155129
try:
156130
o = self.data[key]()
157131
except KeyError:
@@ -162,38 +136,28 @@ def __repr__(self):
162136
return "<%s at %#x>" % (self.__class__.__name__, id(self))
163137

164138
def __setitem__(self, key, value):
165-
if self._pending_removals:
166-
self._commit_removals()
167139
self.data[key] = KeyedRef(value, self._remove, key)
168140

169141
def copy(self):
170-
if self._pending_removals:
171-
self._commit_removals()
172142
new = WeakValueDictionary()
173-
with _IterationGuard(self):
174-
for key, wr in self.data.items():
175-
o = wr()
176-
if o is not None:
177-
new[key] = o
143+
for key, wr in self.data.copy().items():
144+
o = wr()
145+
if o is not None:
146+
new[key] = o
178147
return new
179148

180149
__copy__ = copy
181150

182151
def __deepcopy__(self, memo):
183152
from copy import deepcopy
184-
if self._pending_removals:
185-
self._commit_removals()
186153
new = self.__class__()
187-
with _IterationGuard(self):
188-
for key, wr in self.data.items():
189-
o = wr()
190-
if o is not None:
191-
new[deepcopy(key, memo)] = o
154+
for key, wr in self.data.copy().items():
155+
o = wr()
156+
if o is not None:
157+
new[deepcopy(key, memo)] = o
192158
return new
193159

194160
def get(self, key, default=None):
195-
if self._pending_removals:
196-
self._commit_removals()
197161
try:
198162
wr = self.data[key]
199163
except KeyError:
@@ -207,21 +171,15 @@ def get(self, key, default=None):
207171
return o
208172

209173
def items(self):
210-
if self._pending_removals:
211-
self._commit_removals()
212-
with _IterationGuard(self):
213-
for k, wr in self.data.items():
214-
v = wr()
215-
if v is not None:
216-
yield k, v
174+
for k, wr in self.data.copy().items():
175+
v = wr()
176+
if v is not None:
177+
yield k, v
217178

218179
def keys(self):
219-
if self._pending_removals:
220-
self._commit_removals()
221-
with _IterationGuard(self):
222-
for k, wr in self.data.items():
223-
if wr() is not None:
224-
yield k
180+
for k, wr in self.data.copy().items():
181+
if wr() is not None:
182+
yield k
225183

226184
__iter__ = keys
227185

@@ -235,32 +193,22 @@ def itervaluerefs(self):
235193
keep the values around longer than needed.
236194
237195
"""
238-
if self._pending_removals:
239-
self._commit_removals()
240-
with _IterationGuard(self):
241-
yield from self.data.values()
196+
yield from self.data.copy().values()
242197

243198
def values(self):
244-
if self._pending_removals:
245-
self._commit_removals()
246-
with _IterationGuard(self):
247-
for wr in self.data.values():
248-
obj = wr()
249-
if obj is not None:
250-
yield obj
199+
for wr in self.data.copy().values():
200+
obj = wr()
201+
if obj is not None:
202+
yield obj
251203

252204
def popitem(self):
253-
if self._pending_removals:
254-
self._commit_removals()
255205
while True:
256206
key, wr = self.data.popitem()
257207
o = wr()
258208
if o is not None:
259209
return key, o
260210

261211
def pop(self, key, *args):
262-
if self._pending_removals:
263-
self._commit_removals()
264212
try:
265213
o = self.data.pop(key)()
266214
except KeyError:
@@ -279,16 +227,12 @@ def setdefault(self, key, default=None):
279227
except KeyError:
280228
o = None
281229
if o is None:
282-
if self._pending_removals:
283-
self._commit_removals()
284230
self.data[key] = KeyedRef(default, self._remove, key)
285231
return default
286232
else:
287233
return o
288234

289235
def update(self, other=None, /, **kwargs):
290-
if self._pending_removals:
291-
self._commit_removals()
292236
d = self.data
293237
if other is not None:
294238
if not hasattr(other, "items"):
@@ -308,9 +252,7 @@ def valuerefs(self):
308252
keep the values around longer than needed.
309253
310254
"""
311-
if self._pending_removals:
312-
self._commit_removals()
313-
return list(self.data.values())
255+
return list(self.data.copy().values())
314256

315257
def __ior__(self, other):
316258
self.update(other)
@@ -369,57 +311,22 @@ def __init__(self, dict=None):
369311
def remove(k, selfref=ref(self)):
370312
self = selfref()
371313
if self is not None:
372-
if self._iterating:
373-
self._pending_removals.append(k)
374-
else:
375-
try:
376-
del self.data[k]
377-
except KeyError:
378-
pass
314+
try:
315+
del self.data[k]
316+
except KeyError:
317+
pass
379318
self._remove = remove
380-
# A list of dead weakrefs (keys to be removed)
381-
self._pending_removals = []
382-
self._iterating = set()
383-
self._dirty_len = False
384319
if dict is not None:
385320
self.update(dict)
386321

387-
def _commit_removals(self):
388-
# NOTE: We don't need to call this method before mutating the dict,
389-
# because a dead weakref never compares equal to a live weakref,
390-
# even if they happened to refer to equal objects.
391-
# However, it means keys may already have been removed.
392-
pop = self._pending_removals.pop
393-
d = self.data
394-
while True:
395-
try:
396-
key = pop()
397-
except IndexError:
398-
return
399-
400-
try:
401-
del d[key]
402-
except KeyError:
403-
pass
404-
405-
def _scrub_removals(self):
406-
d = self.data
407-
self._pending_removals = [k for k in self._pending_removals if k in d]
408-
self._dirty_len = False
409-
410322
def __delitem__(self, key):
411-
self._dirty_len = True
412323
del self.data[ref(key)]
413324

414325
def __getitem__(self, key):
415326
return self.data[ref(key)]
416327

417328
def __len__(self):
418-
if self._dirty_len and self._pending_removals:
419-
# self._pending_removals may still contain keys which were
420-
# explicitly removed, we have to scrub them (see issue #21173).
421-
self._scrub_removals()
422-
return len(self.data) - len(self._pending_removals)
329+
return len(self.data)
423330

424331
def __repr__(self):
425332
return "<%s at %#x>" % (self.__class__.__name__, id(self))
@@ -429,23 +336,21 @@ def __setitem__(self, key, value):
429336

430337
def copy(self):
431338
new = WeakKeyDictionary()
432-
with _IterationGuard(self):
433-
for key, value in self.data.items():
434-
o = key()
435-
if o is not None:
436-
new[o] = value
339+
for key, value in self.data.copy().items():
340+
o = key()
341+
if o is not None:
342+
new[o] = value
437343
return new
438344

439345
__copy__ = copy
440346

441347
def __deepcopy__(self, memo):
442348
from copy import deepcopy
443349
new = self.__class__()
444-
with _IterationGuard(self):
445-
for key, value in self.data.items():
446-
o = key()
447-
if o is not None:
448-
new[o] = deepcopy(value, memo)
350+
for key, value in self.data.copy().items():
351+
o = key()
352+
if o is not None:
353+
new[o] = deepcopy(value, memo)
449354
return new
450355

451356
def get(self, key, default=None):
@@ -459,26 +364,23 @@ def __contains__(self, key):
459364
return wr in self.data
460365

461366
def items(self):
462-
with _IterationGuard(self):
463-
for wr, value in self.data.items():
464-
key = wr()
465-
if key is not None:
466-
yield key, value
367+
for wr, value in self.data.copy().items():
368+
key = wr()
369+
if key is not None:
370+
yield key, value
467371

468372
def keys(self):
469-
with _IterationGuard(self):
470-
for wr in self.data:
471-
obj = wr()
472-
if obj is not None:
473-
yield obj
373+
for wr in self.data.copy():
374+
obj = wr()
375+
if obj is not None:
376+
yield obj
474377

475378
__iter__ = keys
476379

477380
def values(self):
478-
with _IterationGuard(self):
479-
for wr, value in self.data.items():
480-
if wr() is not None:
481-
yield value
381+
for wr, value in self.data.copy().items():
382+
if wr() is not None:
383+
yield value
482384

483385
def keyrefs(self):
484386
"""Return a list of weak references to the keys.
@@ -493,15 +395,13 @@ def keyrefs(self):
493395
return list(self.data)
494396

495397
def popitem(self):
496-
self._dirty_len = True
497398
while True:
498399
key, value = self.data.popitem()
499400
o = key()
500401
if o is not None:
501402
return o, value
502403

503404
def pop(self, key, *args):
504-
self._dirty_len = True
505405
return self.data.pop(ref(key), *args)
506406

507407
def setdefault(self, key, default=None):
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make :class:`~weakref.WeakKeyDictionary` and :class:`~weakref.WeakValueDictionary` safe against concurrent mutations from other threads. Patch by Kumar Aditya.

0 commit comments

Comments
 (0)