Beef-up docs and tests for itertools. Fix-up end-case for product().

This commit is contained in:
Raymond Hettinger 2008-03-04 04:17:08 +00:00
parent 378586a844
commit d553d856e7
3 changed files with 145 additions and 19 deletions

View File

@ -89,6 +89,7 @@ loops that truncate the stream.
.. versionadded:: 2.6
.. function:: combinations(iterable, r)
Return successive *r* length combinations of elements in the *iterable*.
@ -123,6 +124,17 @@ loops that truncate the stream.
indices[j] = indices[j-1] + 1
yield tuple(pool[i] for i in indices)
The code for :func:`combinations` can be also expressed as a subsequence
of :func:`permutations` after filtering entries where the elements are not
in sorted order (according to their position in the input pool)::
def combinations(iterable, r):
pool = tuple(iterable)
n = len(pool)
for indices in permutations(range(n), r):
if sorted(indices) == list(indices):
yield tuple(pool[i] for i in indices)
.. versionadded:: 2.6
.. function:: count([n])
@ -391,6 +403,18 @@ loops that truncate the stream.
else:
return
The code for :func:`permutations` can be also expressed as a subsequence of
:func:`product`, filtered to exclude entries with repeated elements (those
from the same position in the input pool)::
def permutations(iterable, r=None):
pool = tuple(iterable)
n = len(pool)
r = n if r is None else r
for indices in product(range(n), repeat=r):
if len(set(indices)) == r:
yield tuple(pool[i] for i in indices)
.. versionadded:: 2.6
.. function:: product(*iterables[, repeat])
@ -401,9 +425,9 @@ loops that truncate the stream.
``product(A, B)`` returns the same as ``((x,y) for x in A for y in B)``.
The leftmost iterators are in the outermost for-loop, so the output tuples
cycle in a manner similar to an odometer (with the rightmost element
changing on every iteration). This results in a lexicographic ordering
so that if the inputs iterables are sorted, the product tuples are emitted
cycle like an odometer (with the rightmost element changing on every
iteration). This results in a lexicographic ordering so that if the
inputs iterables are sorted, the product tuples are emitted
in sorted order.
To compute the product of an iterable with itself, specify the number of
@ -415,12 +439,11 @@ loops that truncate the stream.
def product(*args, **kwds):
pools = map(tuple, args) * kwds.get('repeat', 1)
if pools:
result = [[]]
for pool in pools:
result = [x+[y] for x in result for y in pool]
for prod in result:
yield tuple(prod)
result = [[]]
for pool in pools:
result = [x+[y] for x in result for y in pool]
for prod in result:
yield tuple(prod)
.. versionadded:: 2.6

View File

@ -40,9 +40,21 @@ def take(n, seq):
'Convenience function for partially consuming a long of infinite iterable'
return list(islice(seq, n))
def prod(iterable):
return reduce(operator.mul, iterable, 1)
def fact(n):
'Factorial'
return reduce(operator.mul, range(1, n+1), 1)
return prod(range(1, n+1))
def permutations(iterable, r=None):
# XXX use this until real permutations code is added
pool = tuple(iterable)
n = len(pool)
r = n if r is None else r
for indices in product(range(n), repeat=r):
if len(set(indices)) == r:
yield tuple(pool[i] for i in indices)
class TestBasicOps(unittest.TestCase):
def test_chain(self):
@ -62,11 +74,38 @@ class TestBasicOps(unittest.TestCase):
def test_combinations(self):
self.assertRaises(TypeError, combinations, 'abc') # missing r argument
self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments
self.assertRaises(TypeError, combinations, None) # pool is not iterable
self.assertRaises(ValueError, combinations, 'abc', -2) # r is negative
self.assertRaises(ValueError, combinations, 'abc', 32) # r is too big
self.assertEqual(list(combinations(range(4), 3)),
[(0,1,2), (0,1,3), (0,2,3), (1,2,3)])
for n in range(8):
def combinations1(iterable, r):
'Pure python version shown in the docs'
pool = tuple(iterable)
n = len(pool)
indices = range(r)
yield tuple(pool[i] for i in indices)
while 1:
for i in reversed(range(r)):
if indices[i] != i + n - r:
break
else:
return
indices[i] += 1
for j in range(i+1, r):
indices[j] = indices[j-1] + 1
yield tuple(pool[i] for i in indices)
def combinations2(iterable, r):
'Pure python version shown in the docs'
pool = tuple(iterable)
n = len(pool)
for indices in permutations(range(n), r):
if sorted(indices) == list(indices):
yield tuple(pool[i] for i in indices)
for n in range(7):
values = [5*x-12 for x in range(n)]
for r in range(n+1):
result = list(combinations(values, r))
@ -78,6 +117,73 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(len(set(c)), r) # no duplicate elements
self.assertEqual(list(c), sorted(c)) # keep original ordering
self.assert_(all(e in values for e in c)) # elements taken from input iterable
self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version
self.assertEqual(result, list(combinations2(values, r))) # matches first pure python version
# Test implementation detail: tuple re-use
self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1)
self.assertNotEqual(len(set(map(id, list(combinations('abcde', 3))))), 1)
def test_permutations(self):
self.assertRaises(TypeError, permutations) # too few arguments
self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments
## self.assertRaises(TypeError, permutations, None) # pool is not iterable
## self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative
## self.assertRaises(ValueError, permutations, 'abc', 32) # r is too big
self.assertEqual(list(permutations(range(3), 2)),
[(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)])
def permutations1(iterable, r=None):
'Pure python version shown in the docs'
pool = tuple(iterable)
n = len(pool)
r = n if r is None else r
indices = range(n)
cycles = range(n-r+1, n+1)[::-1]
yield tuple(pool[i] for i in indices[:r])
while n:
for i in reversed(range(r)):
cycles[i] -= 1
if cycles[i] == 0:
indices[i:] = indices[i+1:] + indices[i:i+1]
cycles[i] = n - i
else:
j = cycles[i]
indices[i], indices[-j] = indices[-j], indices[i]
yield tuple(pool[i] for i in indices[:r])
break
else:
return
def permutations2(iterable, r=None):
'Pure python version shown in the docs'
pool = tuple(iterable)
n = len(pool)
r = n if r is None else r
for indices in product(range(n), repeat=r):
if len(set(indices)) == r:
yield tuple(pool[i] for i in indices)
for n in range(7):
values = [5*x-12 for x in range(n)]
for r in range(n+1):
result = list(permutations(values, r))
self.assertEqual(len(result), fact(n) / fact(n-r)) # right number of perms
self.assertEqual(len(result), len(set(result))) # no repeats
self.assertEqual(result, sorted(result)) # lexicographic order
for p in result:
self.assertEqual(len(p), r) # r-length permutations
self.assertEqual(len(set(p)), r) # no duplicate elements
self.assert_(all(e in values for e in p)) # elements taken from input iterable
self.assertEqual(result, list(permutations1(values, r))) # matches first pure python version
self.assertEqual(result, list(permutations2(values, r))) # matches first pure python version
if r == n:
self.assertEqual(result, list(permutations(values, None))) # test r as None
self.assertEqual(result, list(permutations(values))) # test default r
# Test implementation detail: tuple re-use
## self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1)
self.assertNotEqual(len(set(map(id, list(permutations('abcde', 3))))), 1)
def test_count(self):
self.assertEqual(zip('abc',count()), [('a', 0), ('b', 1), ('c', 2)])
@ -288,7 +394,7 @@ class TestBasicOps(unittest.TestCase):
def test_product(self):
for args, result in [
([], []), # zero iterables ??? is this correct
([], [()]), # zero iterables
(['ab'], [('a',), ('b',)]), # one iterable
([range(2), range(3)], [(0,0), (0,1), (0,2), (1,0), (1,1), (1,2)]), # two iterables
([range(0), range(2), range(3)], []), # first iterable with zero length
@ -305,10 +411,10 @@ class TestBasicOps(unittest.TestCase):
set('abcdefg'), range(11), tuple(range(13))]
for i in range(100):
args = [random.choice(argtypes) for j in range(random.randrange(5))]
n = reduce(operator.mul, map(len, args), 1) if args else 0
self.assertEqual(len(list(product(*args))), n)
expected_len = prod(map(len, args))
self.assertEqual(len(list(product(*args))), expected_len)
args = map(iter, args)
self.assertEqual(len(list(product(*args))), n)
self.assertEqual(len(list(product(*args))), expected_len)
# Test implementation detail: tuple re-use
self.assertEqual(len(set(map(id, product('abc', 'def')))), 1)

View File

@ -1885,10 +1885,7 @@ product_next(productobject *lz)
if (result == NULL) {
/* On the first pass, return an initial tuple filled with the
first element from each pool. If any pool is empty, then
whole product is empty and we're already done */
if (npools == 0)
goto empty;
first element from each pool. */
result = PyTuple_New(npools);
if (result == NULL)
goto empty;