From 855d9a985b861cc2c475f4020c120a25548b4c98 Mon Sep 17 00:00:00 2001 From: Raymond Hettinger Date: Tue, 28 Sep 2004 00:03:54 +0000 Subject: [PATCH] Plug a leak and beef-up test coverage. --- Lib/test/test_heapq.py | 149 +++++++++++++++++++++++++++++++++++++++++ Modules/_heapqmodule.c | 16 +++-- 2 files changed, 161 insertions(+), 4 deletions(-) diff --git a/Lib/test/test_heapq.py b/Lib/test/test_heapq.py index d7098d97629..7848e4e98a0 100644 --- a/Lib/test/test_heapq.py +++ b/Lib/test/test_heapq.py @@ -110,8 +110,157 @@ class TestHeap(unittest.TestCase): for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100): self.assertEqual(nlargest(n, data), sorted(data, reverse=True)[:n]) + +#============================================================================== + +class LenOnly: + "Dummy sequence class defining __len__ but not __getitem__." + def __len__(self): + return 10 + +class GetOnly: + "Dummy sequence class defining __getitem__ but not __len__." + def __getitem__(self, ndx): + return 10 + +class CmpErr: + "Dummy element that always raises an error during comparison" + def __cmp__(self, other): + raise ZeroDivisionError + +def R(seqn): + 'Regular generator' + for i in seqn: + yield i + +class G: + 'Sequence using __getitem__' + def __init__(self, seqn): + self.seqn = seqn + def __getitem__(self, i): + return self.seqn[i] + +class I: + 'Sequence using iterator protocol' + def __init__(self, seqn): + self.seqn = seqn + self.i = 0 + def __iter__(self): + return self + def next(self): + if self.i >= len(self.seqn): raise StopIteration + v = self.seqn[self.i] + self.i += 1 + return v + +class Ig: + 'Sequence using iterator protocol defined with a generator' + def __init__(self, seqn): + self.seqn = seqn + self.i = 0 + def __iter__(self): + for val in self.seqn: + yield val + +class X: + 'Missing __getitem__ and __iter__' + def __init__(self, seqn): + self.seqn = seqn + self.i = 0 + def next(self): + if self.i >= len(self.seqn): raise StopIteration + v = self.seqn[self.i] + self.i += 1 + return v + +class N: + 'Iterator missing next()' + def __init__(self, seqn): + self.seqn = seqn + self.i = 0 + def __iter__(self): + return self + +class E: + 'Test propagation of exceptions' + def __init__(self, seqn): + self.seqn = seqn + self.i = 0 + def __iter__(self): + return self + def next(self): + 3 // 0 + +class S: + 'Test immediate stop' + def __init__(self, seqn): + pass + def __iter__(self): + return self + def next(self): + raise StopIteration + +from itertools import chain, imap +def L(seqn): + 'Test multiple tiers of iterators' + return chain(imap(lambda x:x, R(Ig(G(seqn))))) + +class TestErrorHandling(unittest.TestCase): + + def test_non_sequence(self): + for f in (heapify, heappop): + self.assertRaises(TypeError, f, 10) + for f in (heappush, heapreplace, nlargest, nsmallest): + self.assertRaises(TypeError, f, 10, 10) + + def test_len_only(self): + for f in (heapify, heappop): + self.assertRaises(TypeError, f, LenOnly()) + for f in (heappush, heapreplace): + self.assertRaises(TypeError, f, LenOnly(), 10) + for f in (nlargest, nsmallest): + self.assertRaises(TypeError, f, 2, LenOnly()) + + def test_get_only(self): + for f in (heapify, heappop): + self.assertRaises(TypeError, f, GetOnly()) + for f in (heappush, heapreplace): + self.assertRaises(TypeError, f, GetOnly(), 10) + for f in (nlargest, nsmallest): + self.assertRaises(TypeError, f, 2, GetOnly()) + + def test_get_only(self): + seq = [CmpErr(), CmpErr(), CmpErr()] + for f in (heapify, heappop): + self.assertRaises(ZeroDivisionError, f, seq) + for f in (heappush, heapreplace): + self.assertRaises(ZeroDivisionError, f, seq, 10) + for f in (nlargest, nsmallest): + self.assertRaises(ZeroDivisionError, f, 2, seq) + + def test_arg_parsing(self): + for f in (heapify, heappop, heappush, heapreplace, nlargest, nsmallest): + self.assertRaises(TypeError, f, 10) + + def test_iterable_args(self): + for f in (nlargest, nsmallest): + for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)): + for g in (G, I, Ig, L, R): + self.assertEqual(f(2, g(s)), f(2,s)) + self.assertEqual(f(2, S(s)), []) + self.assertRaises(TypeError, f, 2, X(s)) + self.assertRaises(TypeError, f, 2, N(s)) + self.assertRaises(ZeroDivisionError, f, 2, E(s)) + +#============================================================================== + + def test_main(verbose=None): + from types import BuiltinFunctionType + test_classes = [TestHeap] + if isinstance(heapify, BuiltinFunctionType): + test_classes.append(TestErrorHandling) test_support.run_unittest(*test_classes) # verify reference counting diff --git a/Modules/_heapqmodule.c b/Modules/_heapqmodule.c index 192e843690a..5a78c453e19 100644 --- a/Modules/_heapqmodule.c +++ b/Modules/_heapqmodule.c @@ -28,8 +28,10 @@ _siftdown(PyListObject *heap, int startpos, int pos) parentpos = (pos - 1) >> 1; parent = PyList_GET_ITEM(heap, parentpos); cmp = PyObject_RichCompareBool(parent, newitem, Py_LE); - if (cmp == -1) + if (cmp == -1) { + Py_DECREF(newitem); return -1; + } if (cmp == 1) break; Py_INCREF(parent); @@ -69,8 +71,10 @@ _siftup(PyListObject *heap, int pos) PyList_GET_ITEM(heap, rightpos), PyList_GET_ITEM(heap, childpos), Py_LE); - if (cmp == -1) + if (cmp == -1) { + Py_DECREF(newitem); return -1; + } if (cmp == 1) childpos = rightpos; } @@ -315,8 +319,10 @@ _siftdownmax(PyListObject *heap, int startpos, int pos) parentpos = (pos - 1) >> 1; parent = PyList_GET_ITEM(heap, parentpos); cmp = PyObject_RichCompareBool(newitem, parent, Py_LE); - if (cmp == -1) + if (cmp == -1) { + Py_DECREF(newitem); return -1; + } if (cmp == 1) break; Py_INCREF(parent); @@ -356,8 +362,10 @@ _siftupmax(PyListObject *heap, int pos) PyList_GET_ITEM(heap, childpos), PyList_GET_ITEM(heap, rightpos), Py_LE); - if (cmp == -1) + if (cmp == -1) { + Py_DECREF(newitem); return -1; + } if (cmp == 1) childpos = rightpos; }