Skip to content

Commit 958064f

Browse files
bpo-39421: Fix posible crash in heapq with custom comparison operators (GH-18118)
* bpo-39421: Fix posible crash in heapq with custom comparison operators * fixup! bpo-39421: Fix posible crash in heapq with custom comparison operators * fixup! fixup! bpo-39421: Fix posible crash in heapq with custom comparison operators (cherry picked from commit 79f89e6) Co-authored-by: Pablo Galindo <[email protected]>
1 parent 36968c1 commit 958064f

File tree

3 files changed

+59
-9
lines changed

3 files changed

+59
-9
lines changed

Lib/test/test_heapq.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,37 @@ def test_heappop_mutating_heap(self):
414414
with self.assertRaises((IndexError, RuntimeError)):
415415
self.module.heappop(heap)
416416

417+
def test_comparison_operator_modifiying_heap(self):
418+
# See bpo-39421: Strong references need to be taken
419+
# when comparing objects as they can alter the heap
420+
class EvilClass(int):
421+
def __lt__(self, o):
422+
heap.clear()
423+
return NotImplemented
424+
425+
heap = []
426+
self.module.heappush(heap, EvilClass(0))
427+
self.assertRaises(IndexError, self.module.heappushpop, heap, 1)
428+
429+
def test_comparison_operator_modifiying_heap_two_heaps(self):
430+
431+
class h(int):
432+
def __lt__(self, o):
433+
list2.clear()
434+
return NotImplemented
435+
436+
class g(int):
437+
def __lt__(self, o):
438+
list1.clear()
439+
return NotImplemented
440+
441+
list1, list2 = [], []
442+
443+
self.module.heappush(list1, h(0))
444+
self.module.heappush(list2, g(0))
445+
446+
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1))
447+
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1))
417448

418449
class TestErrorHandlingPython(TestErrorHandling, TestCase):
419450
module = py_heapq
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Fix possible crashes when operating with the functions in the :mod:`heapq`
2+
module and custom comparison operators.

Modules/_heapqmodule.c

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,11 @@ siftdown(PyListObject *heap, Py_ssize_t startpos, Py_ssize_t pos)
2929
while (pos > startpos) {
3030
parentpos = (pos - 1) >> 1;
3131
parent = arr[parentpos];
32+
Py_INCREF(newitem);
33+
Py_INCREF(parent);
3234
cmp = PyObject_RichCompareBool(newitem, parent, Py_LT);
35+
Py_DECREF(parent);
36+
Py_DECREF(newitem);
3337
if (cmp < 0)
3438
return -1;
3539
if (size != PyList_GET_SIZE(heap)) {
@@ -71,10 +75,13 @@ siftup(PyListObject *heap, Py_ssize_t pos)
7175
/* Set childpos to index of smaller child. */
7276
childpos = 2*pos + 1; /* leftmost child position */
7377
if (childpos + 1 < endpos) {
74-
cmp = PyObject_RichCompareBool(
75-
arr[childpos],
76-
arr[childpos + 1],
77-
Py_LT);
78+
PyObject* a = arr[childpos];
79+
PyObject* b = arr[childpos + 1];
80+
Py_INCREF(a);
81+
Py_INCREF(b);
82+
cmp = PyObject_RichCompareBool(a, b, Py_LT);
83+
Py_DECREF(a);
84+
Py_DECREF(b);
7885
if (cmp < 0)
7986
return -1;
8087
childpos += ((unsigned)cmp ^ 1); /* increment when cmp==0 */
@@ -229,7 +236,10 @@ heappushpop(PyObject *self, PyObject *args)
229236
return item;
230237
}
231238

232-
cmp = PyObject_RichCompareBool(PyList_GET_ITEM(heap, 0), item, Py_LT);
239+
PyObject* top = PyList_GET_ITEM(heap, 0);
240+
Py_INCREF(top);
241+
cmp = PyObject_RichCompareBool(top, item, Py_LT);
242+
Py_DECREF(top);
233243
if (cmp < 0)
234244
return NULL;
235245
if (cmp == 0) {
@@ -383,7 +393,11 @@ siftdown_max(PyListObject *heap, Py_ssize_t startpos, Py_ssize_t pos)
383393
while (pos > startpos) {
384394
parentpos = (pos - 1) >> 1;
385395
parent = arr[parentpos];
396+
Py_INCREF(parent);
397+
Py_INCREF(newitem);
386398
cmp = PyObject_RichCompareBool(parent, newitem, Py_LT);
399+
Py_DECREF(parent);
400+
Py_DECREF(newitem);
387401
if (cmp < 0)
388402
return -1;
389403
if (size != PyList_GET_SIZE(heap)) {
@@ -425,10 +439,13 @@ siftup_max(PyListObject *heap, Py_ssize_t pos)
425439
/* Set childpos to index of smaller child. */
426440
childpos = 2*pos + 1; /* leftmost child position */
427441
if (childpos + 1 < endpos) {
428-
cmp = PyObject_RichCompareBool(
429-
arr[childpos + 1],
430-
arr[childpos],
431-
Py_LT);
442+
PyObject* a = arr[childpos + 1];
443+
PyObject* b = arr[childpos];
444+
Py_INCREF(a);
445+
Py_INCREF(b);
446+
cmp = PyObject_RichCompareBool(a, b, Py_LT);
447+
Py_DECREF(a);
448+
Py_DECREF(b);
432449
if (cmp < 0)
433450
return -1;
434451
childpos += ((unsigned)cmp ^ 1); /* increment when cmp==0 */

0 commit comments

Comments
 (0)