mirror of https://github.com/python/cpython
bpo-39421: Fix posible crash in heapq with custom comparison operators (GH-18118)
* bpo-39421: Fix posible crash in heapq with custom comparison operators * fixup! bpo-39421: Fix posible crash in heapq with custom comparison operators * fixup! fixup! bpo-39421: Fix posible crash in heapq with custom comparison operators
This commit is contained in:
parent
13bc13960c
commit
79f89e6e5a
|
@ -432,6 +432,37 @@ class TestErrorHandling:
|
|||
with self.assertRaises((IndexError, RuntimeError)):
|
||||
self.module.heappop(heap)
|
||||
|
||||
def test_comparison_operator_modifiying_heap(self):
|
||||
# See bpo-39421: Strong references need to be taken
|
||||
# when comparing objects as they can alter the heap
|
||||
class EvilClass(int):
|
||||
def __lt__(self, o):
|
||||
heap.clear()
|
||||
return NotImplemented
|
||||
|
||||
heap = []
|
||||
self.module.heappush(heap, EvilClass(0))
|
||||
self.assertRaises(IndexError, self.module.heappushpop, heap, 1)
|
||||
|
||||
def test_comparison_operator_modifiying_heap_two_heaps(self):
|
||||
|
||||
class h(int):
|
||||
def __lt__(self, o):
|
||||
list2.clear()
|
||||
return NotImplemented
|
||||
|
||||
class g(int):
|
||||
def __lt__(self, o):
|
||||
list1.clear()
|
||||
return NotImplemented
|
||||
|
||||
list1, list2 = [], []
|
||||
|
||||
self.module.heappush(list1, h(0))
|
||||
self.module.heappush(list2, g(0))
|
||||
|
||||
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1))
|
||||
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1))
|
||||
|
||||
class TestErrorHandlingPython(TestErrorHandling, TestCase):
|
||||
module = py_heapq
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
Fix possible crashes when operating with the functions in the :mod:`heapq`
|
||||
module and custom comparison operators.
|
|
@ -36,7 +36,11 @@ siftdown(PyListObject *heap, Py_ssize_t startpos, Py_ssize_t pos)
|
|||
while (pos > startpos) {
|
||||
parentpos = (pos - 1) >> 1;
|
||||
parent = arr[parentpos];
|
||||
Py_INCREF(newitem);
|
||||
Py_INCREF(parent);
|
||||
cmp = PyObject_RichCompareBool(newitem, parent, Py_LT);
|
||||
Py_DECREF(parent);
|
||||
Py_DECREF(newitem);
|
||||
if (cmp < 0)
|
||||
return -1;
|
||||
if (size != PyList_GET_SIZE(heap)) {
|
||||
|
@ -78,10 +82,13 @@ siftup(PyListObject *heap, Py_ssize_t pos)
|
|||
/* Set childpos to index of smaller child. */
|
||||
childpos = 2*pos + 1; /* leftmost child position */
|
||||
if (childpos + 1 < endpos) {
|
||||
cmp = PyObject_RichCompareBool(
|
||||
arr[childpos],
|
||||
arr[childpos + 1],
|
||||
Py_LT);
|
||||
PyObject* a = arr[childpos];
|
||||
PyObject* b = arr[childpos + 1];
|
||||
Py_INCREF(a);
|
||||
Py_INCREF(b);
|
||||
cmp = PyObject_RichCompareBool(a, b, Py_LT);
|
||||
Py_DECREF(a);
|
||||
Py_DECREF(b);
|
||||
if (cmp < 0)
|
||||
return -1;
|
||||
childpos += ((unsigned)cmp ^ 1); /* increment when cmp==0 */
|
||||
|
@ -264,7 +271,10 @@ _heapq_heappushpop_impl(PyObject *module, PyObject *heap, PyObject *item)
|
|||
return item;
|
||||
}
|
||||
|
||||
cmp = PyObject_RichCompareBool(PyList_GET_ITEM(heap, 0), item, Py_LT);
|
||||
PyObject* top = PyList_GET_ITEM(heap, 0);
|
||||
Py_INCREF(top);
|
||||
cmp = PyObject_RichCompareBool(top, item, Py_LT);
|
||||
Py_DECREF(top);
|
||||
if (cmp < 0)
|
||||
return NULL;
|
||||
if (cmp == 0) {
|
||||
|
@ -420,7 +430,11 @@ siftdown_max(PyListObject *heap, Py_ssize_t startpos, Py_ssize_t pos)
|
|||
while (pos > startpos) {
|
||||
parentpos = (pos - 1) >> 1;
|
||||
parent = arr[parentpos];
|
||||
Py_INCREF(parent);
|
||||
Py_INCREF(newitem);
|
||||
cmp = PyObject_RichCompareBool(parent, newitem, Py_LT);
|
||||
Py_DECREF(parent);
|
||||
Py_DECREF(newitem);
|
||||
if (cmp < 0)
|
||||
return -1;
|
||||
if (size != PyList_GET_SIZE(heap)) {
|
||||
|
@ -462,10 +476,13 @@ siftup_max(PyListObject *heap, Py_ssize_t pos)
|
|||
/* Set childpos to index of smaller child. */
|
||||
childpos = 2*pos + 1; /* leftmost child position */
|
||||
if (childpos + 1 < endpos) {
|
||||
cmp = PyObject_RichCompareBool(
|
||||
arr[childpos + 1],
|
||||
arr[childpos],
|
||||
Py_LT);
|
||||
PyObject* a = arr[childpos + 1];
|
||||
PyObject* b = arr[childpos];
|
||||
Py_INCREF(a);
|
||||
Py_INCREF(b);
|
||||
cmp = PyObject_RichCompareBool(a, b, Py_LT);
|
||||
Py_DECREF(a);
|
||||
Py_DECREF(b);
|
||||
if (cmp < 0)
|
||||
return -1;
|
||||
childpos += ((unsigned)cmp ^ 1); /* increment when cmp==0 */
|
||||
|
|
Loading…
Reference in New Issue