From 7c2bb5bc5765491a80bdca43545c5f41481cbf08 Mon Sep 17 00:00:00 2001 From: Raymond Hettinger Date: Sat, 3 May 2003 05:59:48 +0000 Subject: [PATCH] * Added a substantial number of edge case and argument tests for the itertoolsmodule. * Taught itertools.repeat(obj, n) to treat negative repeat counts as zero. This behavior matches that for sequences and prevents infinite loops. --- Lib/test/test_itertools.py | 91 +++++++++++++++++++++++++++++++------- Modules/itertoolsmodule.c | 3 ++ 2 files changed, 79 insertions(+), 15 deletions(-) diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index e03ce8c18d7..7a4fa209ab9 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -2,63 +2,119 @@ import unittest from test import test_support from itertools import * import sys +import operator + +def onearg(x): + 'Test function of one argument' + return x + +def errfunc(*args): + 'Test function that raises an error' + raise ValueError + +def gen3(): + 'Non-restartable source sequence' + for i in (0, 1, 2): + yield i + +def isEven(x): + 'Test predicate' + return x%2==0 + +class StopNow: + 'Class emulating an empty iterable.' + def __iter__(self): + return self + def next(self): + raise StopIteration class TestBasicOps(unittest.TestCase): def test_chain(self): self.assertEqual(list(chain('abc', 'def')), list('abcdef')) + self.assertEqual(list(chain('abc')), list('abc')) + self.assertEqual(list(chain('')), []) + self.assertRaises(TypeError, chain, 2, 3) def test_count(self): self.assertEqual(zip('abc',count()), [('a', 0), ('b', 1), ('c', 2)]) self.assertEqual(zip('abc',count(3)), [('a', 3), ('b', 4), ('c', 5)]) self.assertRaises(TypeError, count, 2, 3) + self.assertRaises(TypeError, count, 'a') + c = count(sys.maxint-2) # verify that rollover doesn't crash + c.next(); c.next(); c.next(); c.next(); c.next() def test_cycle(self): self.assertEqual(list(islice(cycle('abc'),10)), list('abcabcabca')) self.assertEqual(list(cycle('')), []) + self.assertRaises(TypeError, cycle) + self.assertRaises(TypeError, cycle, 5) + self.assertEqual(list(islice(cycle(gen3()),10)), [0,1,2,0,1,2,0,1,2,0]) def test_ifilter(self): - def isEven(x): - return x%2==0 self.assertEqual(list(ifilter(isEven, range(6))), [0,2,4]) self.assertEqual(list(ifilter(None, [0,1,0,2,0])), [1,2]) self.assertRaises(TypeError, ifilter) - self.assertRaises(TypeError, ifilter, 3) + self.assertRaises(TypeError, ifilter, lambda x:x) + self.assertRaises(TypeError, ifilter, lambda x:x, range(6), 7) self.assertRaises(TypeError, ifilter, isEven, 3) + self.assertRaises(TypeError, ifilter(range(6), range(6)).next) def test_ifilterfalse(self): - def isEven(x): - return x%2==0 self.assertEqual(list(ifilterfalse(isEven, range(6))), [1,3,5]) self.assertEqual(list(ifilterfalse(None, [0,1,0,2,0])), [0,0,0]) self.assertRaises(TypeError, ifilterfalse) - self.assertRaises(TypeError, ifilterfalse, 3) + self.assertRaises(TypeError, ifilterfalse, lambda x:x) + self.assertRaises(TypeError, ifilterfalse, lambda x:x, range(6), 7) self.assertRaises(TypeError, ifilterfalse, isEven, 3) + self.assertRaises(TypeError, ifilterfalse(range(6), range(6)).next) def test_izip(self): ans = [(x,y) for x, y in izip('abc',count())] self.assertEqual(ans, [('a', 0), ('b', 1), ('c', 2)]) + self.assertEqual(list(izip('abc', range(6))), zip('abc', range(6))) + self.assertEqual(list(izip('abcdef', range(3))), zip('abcdef', range(3))) + self.assertEqual(list(izip('abcdef')), zip('abcdef')) self.assertRaises(TypeError, izip) + self.assertRaises(TypeError, izip, 3) + self.assertRaises(TypeError, izip, range(3), 3) + # Check tuple re-use (implementation detail) + self.assertEqual([tuple(list(pair)) for pair in izip('abc', 'def')], + zip('abc', 'def')) + ids = map(id, izip('abc', 'def')) + self.assertEqual(min(ids), max(ids)) def test_repeat(self): self.assertEqual(zip(xrange(3),repeat('a')), [(0, 'a'), (1, 'a'), (2, 'a')]) self.assertEqual(list(repeat('a', 3)), ['a', 'a', 'a']) + self.assertEqual(list(repeat('a', 0)), []) + self.assertEqual(list(repeat('a', -3)), []) self.assertRaises(TypeError, repeat) + self.assertRaises(TypeError, repeat, None, 3, 4) + self.assertRaises(TypeError, repeat, None, 'a') def test_imap(self): - import operator self.assertEqual(list(imap(operator.pow, range(3), range(1,7))), [0**1, 1**2, 2**3]) self.assertEqual(list(imap(None, 'abc', range(5))), [('a',0),('b',1),('c',2)]) + self.assertEqual(list(imap(operator.pow, [])), []) self.assertRaises(TypeError, imap) self.assertRaises(TypeError, imap, operator.neg) + self.assertRaises(TypeError, imap(10, range(5)).next) + self.assertRaises(ValueError, imap(errfunc, [4], [5]).next) + self.assertRaises(TypeError, imap(onearg, [4], [5]).next) def test_starmap(self): - import operator self.assertEqual(list(starmap(operator.pow, zip(range(3), range(1,7)))), [0**1, 1**2, 2**3]) + self.assertEqual(list(starmap(operator.pow, [])), []) self.assertRaises(TypeError, list, starmap(operator.pow, [[4,5]])) + self.assertRaises(TypeError, starmap) + self.assertRaises(TypeError, starmap, operator.pow, [(4,5)], 'extra') + self.assertRaises(TypeError, starmap(10, [(4,5)]).next) + self.assertRaises(ValueError, starmap(errfunc, [(4,5)]).next) + self.assertRaises(TypeError, starmap(onearg, [(4,5)]).next) def test_islice(self): for args in [ # islice(args) should agree with range(args) @@ -100,20 +156,25 @@ class TestBasicOps(unittest.TestCase): data = [1, 3, 5, 20, 2, 4, 6, 8] underten = lambda x: x<10 self.assertEqual(list(takewhile(underten, data)), [1, 3, 5]) + self.assertEqual(list(takewhile(underten, [])), []) + self.assertRaises(TypeError, takewhile) + self.assertRaises(TypeError, takewhile, operator.pow) + self.assertRaises(TypeError, takewhile, operator.pow, [(4,5)], 'extra') + self.assertRaises(TypeError, takewhile(10, [(4,5)]).next) + self.assertRaises(ValueError, takewhile(errfunc, [(4,5)]).next) def test_dropwhile(self): data = [1, 3, 5, 20, 2, 4, 6, 8] underten = lambda x: x<10 self.assertEqual(list(dropwhile(underten, data)), [20, 2, 4, 6, 8]) + self.assertEqual(list(dropwhile(underten, [])), []) + self.assertRaises(TypeError, dropwhile) + self.assertRaises(TypeError, dropwhile, operator.pow) + self.assertRaises(TypeError, dropwhile, operator.pow, [(4,5)], 'extra') + self.assertRaises(TypeError, dropwhile(10, [(4,5)]).next) + self.assertRaises(ValueError, dropwhile(errfunc, [(4,5)]).next) def test_StopIteration(self): - class StopNow: - """Test support class . Emulates an empty iterable.""" - def __iter__(self): - return self - def next(self): - raise StopIteration - for f in (chain, cycle, izip): self.assertRaises(StopIteration, f([]).next) self.assertRaises(StopIteration, f(StopNow()).next) diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index 2d496b5c582..b52f34923ad 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -1691,6 +1691,9 @@ repeat_new(PyTypeObject *type, PyObject *args, PyObject *kwds) if (!PyArg_ParseTuple(args, "O|l:repeat", &element, &cnt)) return NULL; + if (PyTuple_Size(args) == 2 && cnt < 0) + cnt = 0; + ro = (repeatobject *)type->tp_alloc(type, 0); if (ro == NULL) return NULL;