More tests

* Test with infinite inputs (using take() on the output)
* Test whether GC can find and eliminate cycles.
This commit is contained in:
Raymond Hettinger 2003-06-29 20:36:23 +00:00
parent 5728815e7b
commit 0242070d04
1 changed files with 59 additions and 3 deletions

View File

@ -32,24 +32,29 @@ class StopNow:
def next(self): def next(self):
raise StopIteration raise StopIteration
def take(n, seq):
'Convenience function for partially consuming a long of infinite iterable'
return list(islice(seq, n))
class TestBasicOps(unittest.TestCase): class TestBasicOps(unittest.TestCase):
def test_chain(self): def test_chain(self):
self.assertEqual(list(chain('abc', 'def')), list('abcdef')) self.assertEqual(list(chain('abc', 'def')), list('abcdef'))
self.assertEqual(list(chain('abc')), list('abc')) self.assertEqual(list(chain('abc')), list('abc'))
self.assertEqual(list(chain('')), []) self.assertEqual(list(chain('')), [])
self.assertEqual(take(4, chain('abc', 'def')), list('abcd'))
self.assertRaises(TypeError, chain, 2, 3) self.assertRaises(TypeError, chain, 2, 3)
def test_count(self): def test_count(self):
self.assertEqual(zip('abc',count()), [('a', 0), ('b', 1), ('c', 2)]) self.assertEqual(zip('abc',count()), [('a', 0), ('b', 1), ('c', 2)])
self.assertEqual(zip('abc',count(3)), [('a', 3), ('b', 4), ('c', 5)]) self.assertEqual(zip('abc',count(3)), [('a', 3), ('b', 4), ('c', 5)])
self.assertEqual(take(2, zip('abc',count(3))), [('a', 3), ('b', 4)])
self.assertRaises(TypeError, count, 2, 3) self.assertRaises(TypeError, count, 2, 3)
self.assertRaises(TypeError, count, 'a') self.assertRaises(TypeError, count, 'a')
c = count(sys.maxint-2) # verify that rollover doesn't crash c = count(sys.maxint-2) # verify that rollover doesn't crash
c.next(); c.next(); c.next(); c.next(); c.next() c.next(); c.next(); c.next(); c.next(); c.next()
def test_cycle(self): def test_cycle(self):
self.assertEqual(list(islice(cycle('abc'),10)), list('abcabcabca')) self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
self.assertEqual(list(cycle('')), []) self.assertEqual(list(cycle('')), [])
self.assertRaises(TypeError, cycle) self.assertRaises(TypeError, cycle)
self.assertRaises(TypeError, cycle, 5) self.assertRaises(TypeError, cycle, 5)
@ -58,6 +63,7 @@ class TestBasicOps(unittest.TestCase):
def test_ifilter(self): def test_ifilter(self):
self.assertEqual(list(ifilter(isEven, range(6))), [0,2,4]) self.assertEqual(list(ifilter(isEven, range(6))), [0,2,4])
self.assertEqual(list(ifilter(None, [0,1,0,2,0])), [1,2]) self.assertEqual(list(ifilter(None, [0,1,0,2,0])), [1,2])
self.assertEqual(take(4, ifilter(isEven, count())), [0,2,4,6])
self.assertRaises(TypeError, ifilter) self.assertRaises(TypeError, ifilter)
self.assertRaises(TypeError, ifilter, lambda x:x) self.assertRaises(TypeError, ifilter, lambda x:x)
self.assertRaises(TypeError, ifilter, lambda x:x, range(6), 7) self.assertRaises(TypeError, ifilter, lambda x:x, range(6), 7)
@ -67,6 +73,7 @@ class TestBasicOps(unittest.TestCase):
def test_ifilterfalse(self): def test_ifilterfalse(self):
self.assertEqual(list(ifilterfalse(isEven, range(6))), [1,3,5]) self.assertEqual(list(ifilterfalse(isEven, range(6))), [1,3,5])
self.assertEqual(list(ifilterfalse(None, [0,1,0,2,0])), [0,0,0]) self.assertEqual(list(ifilterfalse(None, [0,1,0,2,0])), [0,0,0])
self.assertEqual(take(4, ifilterfalse(isEven, count())), [1,3,5,7])
self.assertRaises(TypeError, ifilterfalse) self.assertRaises(TypeError, ifilterfalse)
self.assertRaises(TypeError, ifilterfalse, lambda x:x) self.assertRaises(TypeError, ifilterfalse, lambda x:x)
self.assertRaises(TypeError, ifilterfalse, lambda x:x, range(6), 7) self.assertRaises(TypeError, ifilterfalse, lambda x:x, range(6), 7)
@ -78,6 +85,7 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(ans, [('a', 0), ('b', 1), ('c', 2)]) self.assertEqual(ans, [('a', 0), ('b', 1), ('c', 2)])
self.assertEqual(list(izip('abc', range(6))), zip('abc', range(6))) self.assertEqual(list(izip('abc', range(6))), zip('abc', range(6)))
self.assertEqual(list(izip('abcdef', range(3))), zip('abcdef', range(3))) self.assertEqual(list(izip('abcdef', range(3))), zip('abcdef', range(3)))
self.assertEqual(take(3,izip('abcdef', count())), zip('abcdef', range(3)))
self.assertEqual(list(izip('abcdef')), zip('abcdef')) self.assertEqual(list(izip('abcdef')), zip('abcdef'))
self.assertRaises(TypeError, izip) self.assertRaises(TypeError, izip)
self.assertRaises(TypeError, izip, 3) self.assertRaises(TypeError, izip, 3)
@ -96,6 +104,7 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(zip(xrange(3),repeat('a')), self.assertEqual(zip(xrange(3),repeat('a')),
[(0, 'a'), (1, 'a'), (2, 'a')]) [(0, 'a'), (1, 'a'), (2, 'a')])
self.assertEqual(list(repeat('a', 3)), ['a', 'a', 'a']) self.assertEqual(list(repeat('a', 3)), ['a', 'a', 'a'])
self.assertEqual(take(3, repeat('a')), ['a', 'a', 'a'])
self.assertEqual(list(repeat('a', 0)), []) self.assertEqual(list(repeat('a', 0)), [])
self.assertEqual(list(repeat('a', -3)), []) self.assertEqual(list(repeat('a', -3)), [])
self.assertRaises(TypeError, repeat) self.assertRaises(TypeError, repeat)
@ -107,6 +116,10 @@ class TestBasicOps(unittest.TestCase):
[0**1, 1**2, 2**3]) [0**1, 1**2, 2**3])
self.assertEqual(list(imap(None, 'abc', range(5))), self.assertEqual(list(imap(None, 'abc', range(5))),
[('a',0),('b',1),('c',2)]) [('a',0),('b',1),('c',2)])
self.assertEqual(list(imap(None, 'abc', count())),
[('a',0),('b',1),('c',2)])
self.assertEqual(take(2,imap(None, 'abc', count())),
[('a',0),('b',1)])
self.assertEqual(list(imap(operator.pow, [])), []) self.assertEqual(list(imap(operator.pow, [])), [])
self.assertRaises(TypeError, imap) self.assertRaises(TypeError, imap)
self.assertRaises(TypeError, imap, operator.neg) self.assertRaises(TypeError, imap, operator.neg)
@ -117,6 +130,8 @@ class TestBasicOps(unittest.TestCase):
def test_starmap(self): def test_starmap(self):
self.assertEqual(list(starmap(operator.pow, zip(range(3), range(1,7)))), self.assertEqual(list(starmap(operator.pow, zip(range(3), range(1,7)))),
[0**1, 1**2, 2**3]) [0**1, 1**2, 2**3])
self.assertEqual(take(3, starmap(operator.pow, izip(count(), count(1)))),
[0**1, 1**2, 2**3])
self.assertEqual(list(starmap(operator.pow, [])), []) self.assertEqual(list(starmap(operator.pow, [])), [])
self.assertRaises(TypeError, list, starmap(operator.pow, [[4,5]])) self.assertRaises(TypeError, list, starmap(operator.pow, [[4,5]]))
self.assertRaises(TypeError, starmap) self.assertRaises(TypeError, starmap)
@ -273,6 +288,45 @@ def L(seqn):
'Test multiple tiers of iterators' 'Test multiple tiers of iterators'
return chain(imap(lambda x:x, R(Ig(G(seqn))))) return chain(imap(lambda x:x, R(Ig(G(seqn)))))
class TestGC(unittest.TestCase):
def makecycle(self, iterator, container):
container.append(iterator)
iterator.next()
del container, iterator
def test_chain(self):
a = []
self.makecycle(chain(a), a)
def test_cycle(self):
a = []
self.makecycle(cycle([a]*2), a)
def test_ifilter(self):
a = []
self.makecycle(ifilter(lambda x:True, [a]*2), a)
def test_ifilterfalse(self):
a = []
self.makecycle(ifilterfalse(lambda x:False, a), a)
def test_izip(self):
a = []
self.makecycle(izip([a]*2, [a]*3), a)
def test_imap(self):
a = []
self.makecycle(imap(lambda x:x, [a]*2), a)
def test_islice(self):
a = []
self.makecycle(islice([a]*2, None), a)
def test_starmap(self):
a = []
self.makecycle(starmap(lambda *t: t, [(a,a)]*2), a)
class TestVariousIteratorArgs(unittest.TestCase): class TestVariousIteratorArgs(unittest.TestCase):
@ -420,7 +474,7 @@ Samuele
... return not nth(ifilterfalse(pred, seq), 0) ... return not nth(ifilterfalse(pred, seq), 0)
>>> def some(pred, seq): >>> def some(pred, seq):
... "Returns True if pred(x) is True at least one element in the iterable" ... "Returns True if pred(x) is True for at least one element in the iterable"
... return bool(nth(ifilter(pred, seq), 0)) ... return bool(nth(ifilter(pred, seq), 0))
>>> def no(pred, seq): >>> def no(pred, seq):
@ -505,14 +559,16 @@ False
__test__ = {'libreftest' : libreftest} __test__ = {'libreftest' : libreftest}
def test_main(verbose=None): def test_main(verbose=None):
test_classes = (TestBasicOps, TestVariousIteratorArgs) test_classes = (TestBasicOps, TestVariousIteratorArgs, TestGC)
test_support.run_unittest(*test_classes) test_support.run_unittest(*test_classes)
# verify reference counting # verify reference counting
if verbose and hasattr(sys, "gettotalrefcount"): if verbose and hasattr(sys, "gettotalrefcount"):
import gc
counts = [None] * 5 counts = [None] * 5
for i in xrange(len(counts)): for i in xrange(len(counts)):
test_support.run_unittest(*test_classes) test_support.run_unittest(*test_classes)
gc.collect()
counts[i] = sys.gettotalrefcount() counts[i] = sys.gettotalrefcount()
print counts print counts