Beef-up docs and tests for itertools. Fix-up end-case for product().

This commit is contained in:
Raymond Hettinger 2008-03-04 04:17:08 +00:00
parent 378586a844
commit d553d856e7
3 changed files with 145 additions and 19 deletions

View File

@ -89,6 +89,7 @@ loops that truncate the stream.
.. versionadded:: 2.6 .. versionadded:: 2.6
.. function:: combinations(iterable, r) .. function:: combinations(iterable, r)
Return successive *r* length combinations of elements in the *iterable*. Return successive *r* length combinations of elements in the *iterable*.
@ -123,6 +124,17 @@ loops that truncate the stream.
indices[j] = indices[j-1] + 1 indices[j] = indices[j-1] + 1
yield tuple(pool[i] for i in indices) yield tuple(pool[i] for i in indices)
The code for :func:`combinations` can also be expressed as a subsequence
of :func:`permutations`, keeping only the entries whose elements are in
sorted order (according to their position in the input pool)::
def combinations(iterable, r):
    # Work on positions rather than values so that duplicate values in
    # the input are still treated as distinct elements.
    items = tuple(iterable)
    for picks in permutations(range(len(items)), r):
        # Keep only the permutations whose positions are ascending.
        if list(picks) != sorted(picks):
            continue
        yield tuple(items[k] for k in picks)
.. versionadded:: 2.6 .. versionadded:: 2.6
.. function:: count([n]) .. function:: count([n])
@ -391,6 +403,18 @@ loops that truncate the stream.
else: else:
return return
The code for :func:`permutations` can also be expressed as a subsequence of
:func:`product`, filtered to exclude entries with repeated elements (those
from the same position in the input pool)::
def permutations(iterable, r=None):
    items = tuple(iterable)
    size = len(items)
    if r is None:
        # Default to full-length permutations.
        r = size
    for picks in product(range(size), repeat=r):
        # A position that occurs twice means an element was reused.
        if len(set(picks)) != r:
            continue
        yield tuple(items[k] for k in picks)
.. versionadded:: 2.6 .. versionadded:: 2.6
.. function:: product(*iterables[, repeat]) .. function:: product(*iterables[, repeat])
@ -401,9 +425,9 @@ loops that truncate the stream.
``product(A, B)`` returns the same as ``((x,y) for x in A for y in B)``. ``product(A, B)`` returns the same as ``((x,y) for x in A for y in B)``.
The leftmost iterators are in the outermost for-loop, so the output tuples The leftmost iterators are in the outermost for-loop, so the output tuples
cycle in a manner similar to an odometer (with the rightmost element cycle like an odometer (with the rightmost element changing on every
changing on every iteration). This results in a lexicographic ordering iteration). This results in a lexicographic ordering so that if the
so that if the inputs iterables are sorted, the product tuples are emitted inputs iterables are sorted, the product tuples are emitted
in sorted order. in sorted order.
To compute the product of an iterable with itself, specify the number of To compute the product of an iterable with itself, specify the number of
@ -415,7 +439,6 @@ loops that truncate the stream.
def product(*args, **kwds): def product(*args, **kwds):
pools = map(tuple, args) * kwds.get('repeat', 1) pools = map(tuple, args) * kwds.get('repeat', 1)
if pools:
result = [[]] result = [[]]
for pool in pools: for pool in pools:
result = [x+[y] for x in result for y in pool] result = [x+[y] for x in result for y in pool]

View File

@ -40,9 +40,21 @@ def take(n, seq):
'Convenience function for partially consuming a long of infinite iterable' 'Convenience function for partially consuming a long of infinite iterable'
return list(islice(seq, n)) return list(islice(seq, n))
def prod(iterable):
    'Return the product of the items in iterable (1 if it is empty)'
    total = 1
    for factor in iterable:
        # Plain binary multiply (not *=) so no operand is updated in place.
        total = total * factor
    return total
def fact(n): def fact(n):
'Factorial' 'Factorial'
return reduce(operator.mul, range(1, n+1), 1) return prod(range(1, n+1))
def permutations(iterable, r=None):
    'Yield r-length permutations of iterable (r defaults to its length)'
    # XXX use this until real permutations code is added
    elements = tuple(iterable)
    count = len(elements)
    if r is None:
        r = count
    for positions in product(range(count), repeat=r):
        # Discard index tuples that reuse a position.
        if len(set(positions)) != r:
            continue
        yield tuple(elements[p] for p in positions)
class TestBasicOps(unittest.TestCase): class TestBasicOps(unittest.TestCase):
def test_chain(self): def test_chain(self):
@ -62,11 +74,38 @@ class TestBasicOps(unittest.TestCase):
def test_combinations(self): def test_combinations(self):
self.assertRaises(TypeError, combinations, 'abc') # missing r argument self.assertRaises(TypeError, combinations, 'abc') # missing r argument
self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments
self.assertRaises(TypeError, combinations, None) # pool is not iterable
self.assertRaises(ValueError, combinations, 'abc', -2) # r is negative self.assertRaises(ValueError, combinations, 'abc', -2) # r is negative
self.assertRaises(ValueError, combinations, 'abc', 32) # r is too big self.assertRaises(ValueError, combinations, 'abc', 32) # r is too big
self.assertEqual(list(combinations(range(4), 3)), self.assertEqual(list(combinations(range(4), 3)),
[(0,1,2), (0,1,3), (0,2,3), (1,2,3)]) [(0,1,2), (0,1,3), (0,2,3), (1,2,3)])
for n in range(8):
def combinations1(iterable, r):
    """Pure python combinations(), based on the version shown in the docs.

    Yields r-length tuples of elements of iterable, chosen by advancing a
    list of ascending positions.  Yields nothing when r exceeds the number
    of elements, which is what combinations2() does for the same input.
    """
    pool = tuple(iterable)
    n = len(pool)
    if r > n:
        # Without this guard the first yield indexes past the end of pool.
        return
    # Must be a real list: the positions are updated in place below.
    indices = list(range(r))
    yield tuple(pool[i] for i in indices)
    while True:
        # Find the rightmost position that is not yet at its largest
        # allowed value (i + n - r).
        i = r - 1
        while i >= 0 and indices[i] == i + n - r:
            i -= 1
        if i < 0:
            # Every position is at its maximum: all combinations emitted.
            return
        indices[i] += 1
        # Reset the positions to the right to the smallest ascending run.
        for j in range(i + 1, r):
            indices[j] = indices[j - 1] + 1
        yield tuple(pool[i] for i in indices)
def combinations2(iterable, r):
    'Pure python version: permutations() filtered to ascending positions'
    items = tuple(iterable)
    for picks in permutations(range(len(items)), r):
        # Positions must already be in sorted order to count.
        if list(picks) != sorted(picks):
            continue
        yield tuple(items[k] for k in picks)
for n in range(7):
values = [5*x-12 for x in range(n)] values = [5*x-12 for x in range(n)]
for r in range(n+1): for r in range(n+1):
result = list(combinations(values, r)) result = list(combinations(values, r))
@ -78,6 +117,73 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(len(set(c)), r) # no duplicate elements self.assertEqual(len(set(c)), r) # no duplicate elements
self.assertEqual(list(c), sorted(c)) # keep original ordering self.assertEqual(list(c), sorted(c)) # keep original ordering
self.assert_(all(e in values for e in c)) # elements taken from input iterable self.assert_(all(e in values for e in c)) # elements taken from input iterable
self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version
self.assertEqual(result, list(combinations2(values, r))) # matches second pure python version
# Test implementation detail: tuple re-use
self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1)
self.assertNotEqual(len(set(map(id, list(combinations('abcde', 3))))), 1)
def test_permutations(self):
self.assertRaises(TypeError, permutations) # too few arguments
self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments
## self.assertRaises(TypeError, permutations, None) # pool is not iterable
## self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative
## self.assertRaises(ValueError, permutations, 'abc', 32) # r is too big
self.assertEqual(list(permutations(range(3), 2)),
[(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)])
def permutations1(iterable, r=None):
    """Pure python permutations(), based on the version shown in the docs.

    Yields r-length tuples of elements of iterable; r defaults to the
    number of elements.  Yields nothing when r exceeds the number of
    elements, which is what permutations2() does for the same input.
    """
    pool = tuple(iterable)
    n = len(pool)
    if r is None:
        r = n
    if r > n:
        # Without this guard a short tuple is yielded and the next step
        # raises IndexError.
        return
    # Both must be real lists: they are updated in place below.
    indices = list(range(n))
    cycles = list(range(n, n - r, -1))      # n, n-1, ..., n-r+1
    yield tuple(pool[i] for i in indices[:r])
    while n:
        for i in reversed(range(r)):
            cycles[i] -= 1
            if cycles[i] == 0:
                # Counter for slot i ran out: move its index to the end,
                # reset the counter, and carry on to the slot on its left.
                indices[i:] = indices[i+1:] + indices[i:i+1]
                cycles[i] = n - i
            else:
                # Swap in the next index for slot i and emit the result.
                j = cycles[i]
                indices[i], indices[-j] = indices[-j], indices[i]
                yield tuple(pool[i] for i in indices[:r])
                break
        else:
            # Every counter ran out in one pass: nothing left to emit.
            return
def permutations2(iterable, r=None):
    'Pure python version: product() filtered to tuples without repeats'
    items = tuple(iterable)
    size = len(items)
    if r is None:
        r = size
    for picks in product(range(size), repeat=r):
        # r distinct positions means no element was used twice.
        if len(set(picks)) != r:
            continue
        yield tuple(items[k] for k in picks)
for n in range(7):
values = [5*x-12 for x in range(n)]
for r in range(n+1):
result = list(permutations(values, r))
self.assertEqual(len(result), fact(n) / fact(n-r)) # right number of perms
self.assertEqual(len(result), len(set(result))) # no repeats
self.assertEqual(result, sorted(result)) # lexicographic order
for p in result:
self.assertEqual(len(p), r) # r-length permutations
self.assertEqual(len(set(p)), r) # no duplicate elements
self.assert_(all(e in values for e in p)) # elements taken from input iterable
self.assertEqual(result, list(permutations1(values, r))) # matches first pure python version
self.assertEqual(result, list(permutations2(values, r))) # matches second pure python version
if r == n:
self.assertEqual(result, list(permutations(values, None))) # test r as None
self.assertEqual(result, list(permutations(values))) # test default r
# Test implementation detail: tuple re-use
## self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1)
self.assertNotEqual(len(set(map(id, list(permutations('abcde', 3))))), 1)
def test_count(self): def test_count(self):
self.assertEqual(zip('abc',count()), [('a', 0), ('b', 1), ('c', 2)]) self.assertEqual(zip('abc',count()), [('a', 0), ('b', 1), ('c', 2)])
@ -288,7 +394,7 @@ class TestBasicOps(unittest.TestCase):
def test_product(self): def test_product(self):
for args, result in [ for args, result in [
([], []), # zero iterables ??? is this correct ([], [()]), # zero iterables
(['ab'], [('a',), ('b',)]), # one iterable (['ab'], [('a',), ('b',)]), # one iterable
([range(2), range(3)], [(0,0), (0,1), (0,2), (1,0), (1,1), (1,2)]), # two iterables ([range(2), range(3)], [(0,0), (0,1), (0,2), (1,0), (1,1), (1,2)]), # two iterables
([range(0), range(2), range(3)], []), # first iterable with zero length ([range(0), range(2), range(3)], []), # first iterable with zero length
@ -305,10 +411,10 @@ class TestBasicOps(unittest.TestCase):
set('abcdefg'), range(11), tuple(range(13))] set('abcdefg'), range(11), tuple(range(13))]
for i in range(100): for i in range(100):
args = [random.choice(argtypes) for j in range(random.randrange(5))] args = [random.choice(argtypes) for j in range(random.randrange(5))]
n = reduce(operator.mul, map(len, args), 1) if args else 0 expected_len = prod(map(len, args))
self.assertEqual(len(list(product(*args))), n) self.assertEqual(len(list(product(*args))), expected_len)
args = map(iter, args) args = map(iter, args)
self.assertEqual(len(list(product(*args))), n) self.assertEqual(len(list(product(*args))), expected_len)
# Test implementation detail: tuple re-use # Test implementation detail: tuple re-use
self.assertEqual(len(set(map(id, product('abc', 'def')))), 1) self.assertEqual(len(set(map(id, product('abc', 'def')))), 1)

View File

@ -1885,10 +1885,7 @@ product_next(productobject *lz)
if (result == NULL) { if (result == NULL) {
/* On the first pass, return an initial tuple filled with the /* On the first pass, return an initial tuple filled with the
first element from each pool. If any pool is empty, then first element from each pool. */
whole product is empty and we're already done */
if (npools == 0)
goto empty;
result = PyTuple_New(npools); result = PyTuple_New(npools);
if (result == NULL) if (result == NULL)
goto empty; goto empty;