diff --git a/Doc/library/itertools.rst b/Doc/library/itertools.rst index 68a4ffd174d..3f2abdc661f 100644 --- a/Doc/library/itertools.rst +++ b/Doc/library/itertools.rst @@ -89,6 +89,7 @@ loops that truncate the stream. .. versionadded:: 2.6 + .. function:: combinations(iterable, r) Return successive *r* length combinations of elements in the *iterable*. @@ -123,6 +124,17 @@ loops that truncate the stream. indices[j] = indices[j-1] + 1 yield tuple(pool[i] for i in indices) + The code for :func:`combinations` can also be expressed as a subsequence + of :func:`permutations` after filtering entries where the elements are not + in sorted order (according to their position in the input pool):: + + def combinations(iterable, r): + pool = tuple(iterable) + n = len(pool) + for indices in permutations(range(n), r): + if sorted(indices) == list(indices): + yield tuple(pool[i] for i in indices) + .. versionadded:: 2.6 .. function:: count([n]) @@ -391,6 +403,18 @@ loops that truncate the stream. else: return + The code for :func:`permutations` can also be expressed as a subsequence of + :func:`product`, filtered to exclude entries with repeated elements (those + from the same position in the input pool):: + + def permutations(iterable, r=None): + pool = tuple(iterable) + n = len(pool) + r = n if r is None else r + for indices in product(range(n), repeat=r): + if len(set(indices)) == r: + yield tuple(pool[i] for i in indices) + .. versionadded:: 2.6 .. function:: product(*iterables[, repeat]) @@ -401,9 +425,9 @@ loops that truncate the stream. ``product(A, B)`` returns the same as ``((x,y) for x in A for y in B)``. The leftmost iterators are in the outermost for-loop, so the output tuples - cycle in a manner similar to an odometer (with the rightmost element - changing on every iteration). 
This results in a lexicographic ordering - so that if the inputs iterables are sorted, the product tuples are emitted + cycle like an odometer (with the rightmost element changing on every + iteration). This results in a lexicographic ordering so that if the + input iterables are sorted, the product tuples are emitted in sorted order. To compute the product of an iterable with itself, specify the number of @@ -415,12 +439,11 @@ loops that truncate the stream. def product(*args, **kwds): pools = map(tuple, args) * kwds.get('repeat', 1) - if pools: - result = [[]] - for pool in pools: - result = [x+[y] for x in result for y in pool] - for prod in result: - yield tuple(prod) + result = [[]] + for pool in pools: + result = [x+[y] for x in result for y in pool] + for prod in result: + yield tuple(prod) .. versionadded:: 2.6 diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 087570c93f1..4197989888f 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -40,9 +40,21 @@ def take(n, seq): 'Convenience function for partially consuming a long of infinite iterable' return list(islice(seq, n)) +def prod(iterable): + return reduce(operator.mul, iterable, 1) + def fact(n): 'Factorial' - return reduce(operator.mul, range(1, n+1), 1) + return prod(range(1, n+1)) + +def permutations(iterable, r=None): + # XXX use this until real permutations code is added + pool = tuple(iterable) + n = len(pool) + r = n if r is None else r + for indices in product(range(n), repeat=r): + if len(set(indices)) == r: + yield tuple(pool[i] for i in indices) class TestBasicOps(unittest.TestCase): def test_chain(self): @@ -62,11 +74,38 @@ class TestBasicOps(unittest.TestCase): def test_combinations(self): self.assertRaises(TypeError, combinations, 'abc') # missing r argument self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments + self.assertRaises(TypeError, combinations, None) # pool is not iterable self.assertRaises(ValueError, 
combinations, 'abc', -2) # r is negative self.assertRaises(ValueError, combinations, 'abc', 32) # r is too big self.assertEqual(list(combinations(range(4), 3)), [(0,1,2), (0,1,3), (0,2,3), (1,2,3)]) - for n in range(8): + + def combinations1(iterable, r): + 'Pure python version shown in the docs' + pool = tuple(iterable) + n = len(pool) + indices = range(r) + yield tuple(pool[i] for i in indices) + while 1: + for i in reversed(range(r)): + if indices[i] != i + n - r: + break + else: + return + indices[i] += 1 + for j in range(i+1, r): + indices[j] = indices[j-1] + 1 + yield tuple(pool[i] for i in indices) + + def combinations2(iterable, r): + 'Pure python version shown in the docs' + pool = tuple(iterable) + n = len(pool) + for indices in permutations(range(n), r): + if sorted(indices) == list(indices): + yield tuple(pool[i] for i in indices) + + for n in range(7): values = [5*x-12 for x in range(n)] for r in range(n+1): result = list(combinations(values, r)) @@ -78,6 +117,73 @@ class TestBasicOps(unittest.TestCase): self.assertEqual(len(set(c)), r) # no duplicate elements self.assertEqual(list(c), sorted(c)) # keep original ordering self.assert_(all(e in values for e in c)) # elements taken from input iterable + self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version + self.assertEqual(result, list(combinations2(values, r))) # matches second pure python version + + # Test implementation detail: tuple re-use + self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1) + self.assertNotEqual(len(set(map(id, list(combinations('abcde', 3))))), 1) + + def test_permutations(self): + self.assertRaises(TypeError, permutations) # too few arguments + self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments +## self.assertRaises(TypeError, permutations, None) # pool is not iterable +## self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative +## self.assertRaises(ValueError, permutations, 'abc', 32) # 
r is too big + self.assertEqual(list(permutations(range(3), 2)), + [(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)]) + + def permutations1(iterable, r=None): + 'Pure python version shown in the docs' + pool = tuple(iterable) + n = len(pool) + r = n if r is None else r + indices = range(n) + cycles = range(n-r+1, n+1)[::-1] + yield tuple(pool[i] for i in indices[:r]) + while n: + for i in reversed(range(r)): + cycles[i] -= 1 + if cycles[i] == 0: + indices[i:] = indices[i+1:] + indices[i:i+1] + cycles[i] = n - i + else: + j = cycles[i] + indices[i], indices[-j] = indices[-j], indices[i] + yield tuple(pool[i] for i in indices[:r]) + break + else: + return + + def permutations2(iterable, r=None): + 'Pure python version shown in the docs' + pool = tuple(iterable) + n = len(pool) + r = n if r is None else r + for indices in product(range(n), repeat=r): + if len(set(indices)) == r: + yield tuple(pool[i] for i in indices) + + for n in range(7): + values = [5*x-12 for x in range(n)] + for r in range(n+1): + result = list(permutations(values, r)) + self.assertEqual(len(result), fact(n) / fact(n-r)) # right number of perms + self.assertEqual(len(result), len(set(result))) # no repeats + self.assertEqual(result, sorted(result)) # lexicographic order + for p in result: + self.assertEqual(len(p), r) # r-length permutations + self.assertEqual(len(set(p)), r) # no duplicate elements + self.assert_(all(e in values for e in p)) # elements taken from input iterable + self.assertEqual(result, list(permutations1(values, r))) # matches first pure python version + self.assertEqual(result, list(permutations2(values, r))) # matches second pure python version + if r == n: + self.assertEqual(result, list(permutations(values, None))) # test r as None + self.assertEqual(result, list(permutations(values))) # test default r + + # Test implementation detail: tuple re-use +## self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1) + self.assertNotEqual(len(set(map(id, list(permutations('abcde', 
3))))), 1) def test_count(self): self.assertEqual(zip('abc',count()), [('a', 0), ('b', 1), ('c', 2)]) @@ -288,7 +394,7 @@ class TestBasicOps(unittest.TestCase): def test_product(self): for args, result in [ - ([], []), # zero iterables ??? is this correct + ([], [()]), # zero iterables (['ab'], [('a',), ('b',)]), # one iterable ([range(2), range(3)], [(0,0), (0,1), (0,2), (1,0), (1,1), (1,2)]), # two iterables ([range(0), range(2), range(3)], []), # first iterable with zero length @@ -305,10 +411,10 @@ class TestBasicOps(unittest.TestCase): set('abcdefg'), range(11), tuple(range(13))] for i in range(100): args = [random.choice(argtypes) for j in range(random.randrange(5))] - n = reduce(operator.mul, map(len, args), 1) if args else 0 - self.assertEqual(len(list(product(*args))), n) + expected_len = prod(map(len, args)) + self.assertEqual(len(list(product(*args))), expected_len) args = map(iter, args) - self.assertEqual(len(list(product(*args))), n) + self.assertEqual(len(list(product(*args))), expected_len) # Test implementation detail: tuple re-use self.assertEqual(len(set(map(id, product('abc', 'def')))), 1) diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index 62b2c19b904..76c848491d0 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -1885,10 +1885,7 @@ product_next(productobject *lz) if (result == NULL) { /* On the first pass, return an initial tuple filled with the - first element from each pool. If any pool is empty, then - whole product is empty and we're already done */ - if (npools == 0) - goto empty; + first element from each pool. */ result = PyTuple_New(npools); if (result == NULL) goto empty;