From 38fb9bee6c25f4d6742ddef94405f7bd2ca65ba3 Mon Sep 17 00:00:00 2001 From: Raymond Hettinger Date: Fri, 7 Mar 2008 01:33:20 +0000 Subject: [PATCH] Tweak recipes and tests --- Doc/library/itertools.rst | 12 ++++++------ Lib/test/test_itertools.py | 37 ++++++++++++++++++++++++++++++------- 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/Doc/library/itertools.rst b/Doc/library/itertools.rst index 9ed0c54de1d..f546fe16ee2 100644 --- a/Doc/library/itertools.rst +++ b/Doc/library/itertools.rst @@ -662,15 +662,15 @@ which incur interpreter overhead. :: def pairwise(iterable): "s -> (s0,s1), (s1,s2), (s2, s3), ..." a, b = tee(iterable) - try: - b.next() - except StopIteration: - pass + for elem in b: + break return izip(a, b) - def grouper(n, iterable, padvalue=None): + def grouper(n, iterable, fillvalue=None): "grouper(3, 'abcdefg', 'x') --> ('a','b','c'), ('d','e','f'), ('g','x','x')" - return izip(*[chain(iterable, repeat(padvalue, n-1))]*n) + args = [iter(iterable)] * n + kwds = dict(fillvalue=fillvalue) + return izip_longest(*args, **kwds) def roundrobin(*iterables): "roundrobin('abc', 'd', 'ef') --> 'a', 'd', 'e', 'b', 'f', 'c'" diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 3bd2255ab9a..696fdebf1ef 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -410,6 +410,28 @@ class TestBasicOps(unittest.TestCase): self.assertEqual(len(list(product(*[range(7)]*6))), 7**6) self.assertRaises(TypeError, product, range(6), None) + def product1(*args, **kwds): + pools = map(tuple, args) * kwds.get('repeat', 1) + n = len(pools) + if n == 0: + yield () + return + if any(len(pool) == 0 for pool in pools): + return + indices = [0] * n + yield tuple(pool[i] for pool, i in zip(pools, indices)) + while 1: + for i in reversed(range(n)): # right to left + if indices[i] == len(pools[i]) - 1: + continue + indices[i] += 1 + for j in range(i+1, n): + indices[j] = 0 + yield tuple(pool[i] for pool, i in zip(pools, indices)) + break + else: + return + def product2(*args, **kwds): 'Pure python version used in docs' pools = map(tuple, args) * kwds.get('repeat', 1) @@ -425,6 +447,7 @@ class TestBasicOps(unittest.TestCase): args = [random.choice(argtypes) for j in range(random.randrange(5))] expected_len = prod(map(len, args)) self.assertEqual(len(list(product(*args))), expected_len) + self.assertEqual(list(product(*args)), list(product1(*args))) self.assertEqual(list(product(*args)), list(product2(*args))) args = map(iter, args) self.assertEqual(len(list(product(*args))), expected_len) @@ -1213,7 +1236,7 @@ Samuele ... return sum(imap(operator.mul, vec1, vec2)) >>> def flatten(listOfLists): -... return list(chain(*listOfLists)) +... return list(chain.from_iterable(listOfLists)) >>> def repeatfunc(func, times=None, *args): ... "Repeat calls to func with specified arguments." @@ -1226,15 +1249,15 @@ Samuele >>> def pairwise(iterable): ... "s -> (s0,s1), (s1,s2), (s2, s3), ..." ... a, b = tee(iterable) -... try: -... b.next() -... except StopIteration: -... pass +... for elem in b: +... break ... return izip(a, b) ->>> def grouper(n, iterable, padvalue=None): +>>> def grouper(n, iterable, fillvalue=None): ... "grouper(3, 'abcdefg', 'x') --> ('a','b','c'), ('d','e','f'), ('g','x','x')" -... return izip(*[chain(iterable, repeat(padvalue, n-1))]*n) +... args = [iter(iterable)] * n +... kwds = dict(fillvalue=fillvalue) +... return izip_longest(*args, **kwds) >>> def roundrobin(*iterables): ... "roundrobin('abc', 'd', 'ef') --> 'a', 'd', 'e', 'b', 'f', 'c'"