Implement heapq in terms of less-than (to match list.sort()).

This commit is contained in:
Raymond Hettinger 2008-05-31 03:24:31 +00:00
parent adff65bc3e
commit 6d7702ecd1
3 changed files with 36 additions and 20 deletions

View File

@ -167,7 +167,7 @@ def heapreplace(heap, item):
def heappushpop(heap, item):
"""Fast version of a heappush followed by a heappop."""
if heap and item > heap[0]:
if heap and heap[0] < item:
item, heap[0] = heap[0], item
_siftup(heap, 0)
return item
@ -240,10 +240,11 @@ def _siftdown(heap, startpos, pos):
while pos > startpos:
parentpos = (pos - 1) >> 1
parent = heap[parentpos]
if parent <= newitem:
break
heap[pos] = parent
pos = parentpos
if newitem < parent:
heap[pos] = parent
pos = parentpos
continue
break
heap[pos] = newitem
# The child indices of heap index pos are already heaps, and we want to make
@ -294,7 +295,7 @@ def _siftup(heap, pos):
while childpos < endpos:
# Set childpos to index of smaller child.
rightpos = childpos + 1
if rightpos < endpos and heap[rightpos] <= heap[childpos]:
if rightpos < endpos and not heap[childpos] < heap[rightpos]:
childpos = rightpos
# Move the smaller child up.
heap[pos] = heap[childpos]

View File

@ -36,6 +36,9 @@ Core and Builtins
Extension Modules
-----------------
- The heapq module does comparisons using LT instead of LE. This
makes its implementation match that used by list.sort().
- Issue #2819: add full-precision summation function to math module,
based on Hettinger's ASPN Python Cookbook recipe.

View File

@ -28,12 +28,12 @@ _siftdown(PyListObject *heap, Py_ssize_t startpos, Py_ssize_t pos)
while (pos > startpos){
parentpos = (pos - 1) >> 1;
parent = PyList_GET_ITEM(heap, parentpos);
cmp = PyObject_RichCompareBool(parent, newitem, Py_LE);
cmp = PyObject_RichCompareBool(newitem, parent, Py_LT);
if (cmp == -1) {
Py_DECREF(newitem);
return -1;
}
if (cmp == 1)
if (cmp == 0)
break;
Py_INCREF(parent);
Py_DECREF(PyList_GET_ITEM(heap, pos));
@ -69,14 +69,14 @@ _siftup(PyListObject *heap, Py_ssize_t pos)
rightpos = childpos + 1;
if (rightpos < endpos) {
cmp = PyObject_RichCompareBool(
PyList_GET_ITEM(heap, rightpos),
PyList_GET_ITEM(heap, childpos),
Py_LE);
PyList_GET_ITEM(heap, rightpos),
Py_LT);
if (cmp == -1) {
Py_DECREF(newitem);
return -1;
}
if (cmp == 1)
if (cmp == 0)
childpos = rightpos;
}
/* Move the smaller child up. */
@ -214,10 +214,10 @@ heappushpop(PyObject *self, PyObject *args)
return item;
}
cmp = PyObject_RichCompareBool(item, PyList_GET_ITEM(heap, 0), Py_LE);
cmp = PyObject_RichCompareBool(PyList_GET_ITEM(heap, 0), item, Py_LT);
if (cmp == -1)
return NULL;
if (cmp == 1) {
if (cmp == 0) {
Py_INCREF(item);
return item;
}
@ -270,6 +270,7 @@ nlargest(PyObject *self, PyObject *args)
{
PyObject *heap=NULL, *elem, *iterable, *sol, *it, *oldelem;
Py_ssize_t i, n;
int cmp;
if (!PyArg_ParseTuple(args, "nO:nlargest", &n, &iterable))
return NULL;
@ -312,7 +313,12 @@ nlargest(PyObject *self, PyObject *args)
else
goto sortit;
}
if (PyObject_RichCompareBool(elem, sol, Py_LE)) {
cmp = PyObject_RichCompareBool(sol, elem, Py_LT);
if (cmp == -1) {
Py_DECREF(elem);
goto fail;
}
if (cmp == 0) {
Py_DECREF(elem);
continue;
}
@ -362,12 +368,12 @@ _siftdownmax(PyListObject *heap, Py_ssize_t startpos, Py_ssize_t pos)
while (pos > startpos){
parentpos = (pos - 1) >> 1;
parent = PyList_GET_ITEM(heap, parentpos);
cmp = PyObject_RichCompareBool(newitem, parent, Py_LE);
cmp = PyObject_RichCompareBool(parent, newitem, Py_LT);
if (cmp == -1) {
Py_DECREF(newitem);
return -1;
}
if (cmp == 1)
if (cmp == 0)
break;
Py_INCREF(parent);
Py_DECREF(PyList_GET_ITEM(heap, pos));
@ -403,14 +409,14 @@ _siftupmax(PyListObject *heap, Py_ssize_t pos)
rightpos = childpos + 1;
if (rightpos < endpos) {
cmp = PyObject_RichCompareBool(
PyList_GET_ITEM(heap, childpos),
PyList_GET_ITEM(heap, rightpos),
Py_LE);
PyList_GET_ITEM(heap, childpos),
Py_LT);
if (cmp == -1) {
Py_DECREF(newitem);
return -1;
}
if (cmp == 1)
if (cmp == 0)
childpos = rightpos;
}
/* Move the smaller child up. */
@ -434,6 +440,7 @@ nsmallest(PyObject *self, PyObject *args)
{
PyObject *heap=NULL, *elem, *iterable, *los, *it, *oldelem;
Py_ssize_t i, n;
int cmp;
if (!PyArg_ParseTuple(args, "nO:nsmallest", &n, &iterable))
return NULL;
@ -477,7 +484,12 @@ nsmallest(PyObject *self, PyObject *args)
else
goto sortit;
}
if (PyObject_RichCompareBool(los, elem, Py_LE)) {
cmp = PyObject_RichCompareBool(elem, los, Py_LT);
if (cmp == -1) {
Py_DECREF(elem);
goto fail;
}
if (cmp == 0) {
Py_DECREF(elem);
continue;
}