diff --git a/Lib/heapq.py b/Lib/heapq.py index a44d1beb047..74f7310a2cb 100644 --- a/Lib/heapq.py +++ b/Lib/heapq.py @@ -133,6 +133,11 @@ from itertools import islice, repeat, count, imap, izip, tee, chain from operator import itemgetter import bisect +def cmp_lt(x, y): + # Use __lt__ if available; otherwise, try __le__. + # In Py3.x, only __lt__ will be called. + return (x < y) if hasattr(x, '__lt__') else (not y <= x) + def heappush(heap, item): """Push item onto heap, maintaining the heap invariant.""" heap.append(item) @@ -167,7 +172,7 @@ def heapreplace(heap, item): def heappushpop(heap, item): """Fast version of a heappush followed by a heappop.""" - if heap and heap[0] < item: + if heap and cmp_lt(heap[0], item): item, heap[0] = heap[0], item _siftup(heap, 0) return item @@ -215,11 +220,10 @@ def nsmallest(n, iterable): pop = result.pop los = result[-1] # los --> Largest of the nsmallest for elem in it: - if los <= elem: - continue - insort(result, elem) - pop() - los = result[-1] + if cmp_lt(elem, los): + insort(result, elem) + pop() + los = result[-1] return result # An alternative approach manifests the whole iterable in memory but # saves comparisons by heapifying all at once. Also, saves time @@ -240,7 +244,7 @@ def _siftdown(heap, startpos, pos): while pos > startpos: parentpos = (pos - 1) >> 1 parent = heap[parentpos] - if newitem < parent: + if cmp_lt(newitem, parent): heap[pos] = parent pos = parentpos continue @@ -295,7 +299,7 @@ def _siftup(heap, pos): while childpos < endpos: # Set childpos to index of smaller child. rightpos = childpos + 1 - if rightpos < endpos and not heap[childpos] < heap[rightpos]: + if rightpos < endpos and not cmp_lt(heap[childpos], heap[rightpos]): childpos = rightpos # Move the smaller child up. heap[pos] = heap[childpos] diff --git a/Lib/test/test_heapq.py b/Lib/test/test_heapq.py index e4d2cc8b887..d5d8c1a179e 100644 --- a/Lib/test/test_heapq.py +++ b/Lib/test/test_heapq.py @@ -209,12 +209,6 @@ class TestHeapC(TestHeap): self.assertEqual(hsort(data, LT), target) self.assertEqual(hsort(data, LE), target) - # As an early adopter, we sanity check the - # test_support.import_fresh_module utility function - def test_accelerated(self): - self.assertTrue(sys.modules['heapq'] is self.module) - self.assertFalse(hasattr(self.module.heapify, 'func_code')) - #============================================================================== @@ -316,16 +310,16 @@ class TestErrorHandling(unittest.TestCase): def test_non_sequence(self): for f in (self.module.heapify, self.module.heappop): - self.assertRaises(TypeError, f, 10) + self.assertRaises((TypeError, AttributeError), f, 10) for f in (self.module.heappush, self.module.heapreplace, self.module.nlargest, self.module.nsmallest): - self.assertRaises(TypeError, f, 10, 10) + self.assertRaises((TypeError, AttributeError), f, 10, 10) def test_len_only(self): for f in (self.module.heapify, self.module.heappop): - self.assertRaises(TypeError, f, LenOnly()) + self.assertRaises((TypeError, AttributeError), f, LenOnly()) for f in (self.module.heappush, self.module.heapreplace): - self.assertRaises(TypeError, f, LenOnly(), 10) + self.assertRaises((TypeError, AttributeError), f, LenOnly(), 10) for f in (self.module.nlargest, self.module.nsmallest): self.assertRaises(TypeError, f, 2, LenOnly()) @@ -342,7 +336,7 @@ class TestErrorHandling(unittest.TestCase): for f in (self.module.heapify, self.module.heappop, self.module.heappush, self.module.heapreplace, self.module.nlargest, self.module.nsmallest): - self.assertRaises(TypeError, f, 10) + self.assertRaises((TypeError, AttributeError), f, 10) def test_iterable_args(self): for f in (self.module.nlargest, self.module.nsmallest):