mirror of https://github.com/python/cpython
Beef-up docs and tests for itertools. Fix-up end-case for product().
This commit is contained in:
parent
378586a844
commit
d553d856e7
|
@ -89,6 +89,7 @@ loops that truncate the stream.
|
||||||
|
|
||||||
.. versionadded:: 2.6
|
.. versionadded:: 2.6
|
||||||
|
|
||||||
|
|
||||||
.. function:: combinations(iterable, r)
|
.. function:: combinations(iterable, r)
|
||||||
|
|
||||||
Return successive *r* length combinations of elements in the *iterable*.
|
Return successive *r* length combinations of elements in the *iterable*.
|
||||||
|
@ -123,6 +124,17 @@ loops that truncate the stream.
|
||||||
indices[j] = indices[j-1] + 1
|
indices[j] = indices[j-1] + 1
|
||||||
yield tuple(pool[i] for i in indices)
|
yield tuple(pool[i] for i in indices)
|
||||||
|
|
||||||
|
The code for :func:`combinations` can be also expressed as a subsequence
|
||||||
|
of :func:`permutations` after filtering entries where the elements are not
|
||||||
|
in sorted order (according to their position in the input pool)::
|
||||||
|
|
||||||
|
def combinations(iterable, r):
|
||||||
|
pool = tuple(iterable)
|
||||||
|
n = len(pool)
|
||||||
|
for indices in permutations(range(n), r):
|
||||||
|
if sorted(indices) == list(indices):
|
||||||
|
yield tuple(pool[i] for i in indices)
|
||||||
|
|
||||||
.. versionadded:: 2.6
|
.. versionadded:: 2.6
|
||||||
|
|
||||||
.. function:: count([n])
|
.. function:: count([n])
|
||||||
|
@ -391,6 +403,18 @@ loops that truncate the stream.
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
The code for :func:`permutations` can be also expressed as a subsequence of
|
||||||
|
:func:`product`, filtered to exclude entries with repeated elements (those
|
||||||
|
from the same position in the input pool)::
|
||||||
|
|
||||||
|
def permutations(iterable, r=None):
|
||||||
|
pool = tuple(iterable)
|
||||||
|
n = len(pool)
|
||||||
|
r = n if r is None else r
|
||||||
|
for indices in product(range(n), repeat=r):
|
||||||
|
if len(set(indices)) == r:
|
||||||
|
yield tuple(pool[i] for i in indices)
|
||||||
|
|
||||||
.. versionadded:: 2.6
|
.. versionadded:: 2.6
|
||||||
|
|
||||||
.. function:: product(*iterables[, repeat])
|
.. function:: product(*iterables[, repeat])
|
||||||
|
@ -401,9 +425,9 @@ loops that truncate the stream.
|
||||||
``product(A, B)`` returns the same as ``((x,y) for x in A for y in B)``.
|
``product(A, B)`` returns the same as ``((x,y) for x in A for y in B)``.
|
||||||
|
|
||||||
The leftmost iterators are in the outermost for-loop, so the output tuples
|
The leftmost iterators are in the outermost for-loop, so the output tuples
|
||||||
cycle in a manner similar to an odometer (with the rightmost element
|
cycle like an odometer (with the rightmost element changing on every
|
||||||
changing on every iteration). This results in a lexicographic ordering
|
iteration). This results in a lexicographic ordering so that if the
|
||||||
so that if the inputs iterables are sorted, the product tuples are emitted
|
inputs iterables are sorted, the product tuples are emitted
|
||||||
in sorted order.
|
in sorted order.
|
||||||
|
|
||||||
To compute the product of an iterable with itself, specify the number of
|
To compute the product of an iterable with itself, specify the number of
|
||||||
|
@ -415,7 +439,6 @@ loops that truncate the stream.
|
||||||
|
|
||||||
def product(*args, **kwds):
|
def product(*args, **kwds):
|
||||||
pools = map(tuple, args) * kwds.get('repeat', 1)
|
pools = map(tuple, args) * kwds.get('repeat', 1)
|
||||||
if pools:
|
|
||||||
result = [[]]
|
result = [[]]
|
||||||
for pool in pools:
|
for pool in pools:
|
||||||
result = [x+[y] for x in result for y in pool]
|
result = [x+[y] for x in result for y in pool]
|
||||||
|
|
|
@ -40,9 +40,21 @@ def take(n, seq):
|
||||||
'Convenience function for partially consuming a long of infinite iterable'
|
'Convenience function for partially consuming a long of infinite iterable'
|
||||||
return list(islice(seq, n))
|
return list(islice(seq, n))
|
||||||
|
|
||||||
|
def prod(iterable):
|
||||||
|
return reduce(operator.mul, iterable, 1)
|
||||||
|
|
||||||
def fact(n):
|
def fact(n):
|
||||||
'Factorial'
|
'Factorial'
|
||||||
return reduce(operator.mul, range(1, n+1), 1)
|
return prod(range(1, n+1))
|
||||||
|
|
||||||
|
def permutations(iterable, r=None):
|
||||||
|
# XXX use this until real permutations code is added
|
||||||
|
pool = tuple(iterable)
|
||||||
|
n = len(pool)
|
||||||
|
r = n if r is None else r
|
||||||
|
for indices in product(range(n), repeat=r):
|
||||||
|
if len(set(indices)) == r:
|
||||||
|
yield tuple(pool[i] for i in indices)
|
||||||
|
|
||||||
class TestBasicOps(unittest.TestCase):
|
class TestBasicOps(unittest.TestCase):
|
||||||
def test_chain(self):
|
def test_chain(self):
|
||||||
|
@ -62,11 +74,38 @@ class TestBasicOps(unittest.TestCase):
|
||||||
def test_combinations(self):
|
def test_combinations(self):
|
||||||
self.assertRaises(TypeError, combinations, 'abc') # missing r argument
|
self.assertRaises(TypeError, combinations, 'abc') # missing r argument
|
||||||
self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments
|
self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments
|
||||||
|
self.assertRaises(TypeError, combinations, None) # pool is not iterable
|
||||||
self.assertRaises(ValueError, combinations, 'abc', -2) # r is negative
|
self.assertRaises(ValueError, combinations, 'abc', -2) # r is negative
|
||||||
self.assertRaises(ValueError, combinations, 'abc', 32) # r is too big
|
self.assertRaises(ValueError, combinations, 'abc', 32) # r is too big
|
||||||
self.assertEqual(list(combinations(range(4), 3)),
|
self.assertEqual(list(combinations(range(4), 3)),
|
||||||
[(0,1,2), (0,1,3), (0,2,3), (1,2,3)])
|
[(0,1,2), (0,1,3), (0,2,3), (1,2,3)])
|
||||||
for n in range(8):
|
|
||||||
|
def combinations1(iterable, r):
|
||||||
|
'Pure python version shown in the docs'
|
||||||
|
pool = tuple(iterable)
|
||||||
|
n = len(pool)
|
||||||
|
indices = range(r)
|
||||||
|
yield tuple(pool[i] for i in indices)
|
||||||
|
while 1:
|
||||||
|
for i in reversed(range(r)):
|
||||||
|
if indices[i] != i + n - r:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
indices[i] += 1
|
||||||
|
for j in range(i+1, r):
|
||||||
|
indices[j] = indices[j-1] + 1
|
||||||
|
yield tuple(pool[i] for i in indices)
|
||||||
|
|
||||||
|
def combinations2(iterable, r):
|
||||||
|
'Pure python version shown in the docs'
|
||||||
|
pool = tuple(iterable)
|
||||||
|
n = len(pool)
|
||||||
|
for indices in permutations(range(n), r):
|
||||||
|
if sorted(indices) == list(indices):
|
||||||
|
yield tuple(pool[i] for i in indices)
|
||||||
|
|
||||||
|
for n in range(7):
|
||||||
values = [5*x-12 for x in range(n)]
|
values = [5*x-12 for x in range(n)]
|
||||||
for r in range(n+1):
|
for r in range(n+1):
|
||||||
result = list(combinations(values, r))
|
result = list(combinations(values, r))
|
||||||
|
@ -78,6 +117,73 @@ class TestBasicOps(unittest.TestCase):
|
||||||
self.assertEqual(len(set(c)), r) # no duplicate elements
|
self.assertEqual(len(set(c)), r) # no duplicate elements
|
||||||
self.assertEqual(list(c), sorted(c)) # keep original ordering
|
self.assertEqual(list(c), sorted(c)) # keep original ordering
|
||||||
self.assert_(all(e in values for e in c)) # elements taken from input iterable
|
self.assert_(all(e in values for e in c)) # elements taken from input iterable
|
||||||
|
self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version
|
||||||
|
self.assertEqual(result, list(combinations2(values, r))) # matches first pure python version
|
||||||
|
|
||||||
|
# Test implementation detail: tuple re-use
|
||||||
|
self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1)
|
||||||
|
self.assertNotEqual(len(set(map(id, list(combinations('abcde', 3))))), 1)
|
||||||
|
|
||||||
|
def test_permutations(self):
|
||||||
|
self.assertRaises(TypeError, permutations) # too few arguments
|
||||||
|
self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments
|
||||||
|
## self.assertRaises(TypeError, permutations, None) # pool is not iterable
|
||||||
|
## self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative
|
||||||
|
## self.assertRaises(ValueError, permutations, 'abc', 32) # r is too big
|
||||||
|
self.assertEqual(list(permutations(range(3), 2)),
|
||||||
|
[(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)])
|
||||||
|
|
||||||
|
def permutations1(iterable, r=None):
|
||||||
|
'Pure python version shown in the docs'
|
||||||
|
pool = tuple(iterable)
|
||||||
|
n = len(pool)
|
||||||
|
r = n if r is None else r
|
||||||
|
indices = range(n)
|
||||||
|
cycles = range(n-r+1, n+1)[::-1]
|
||||||
|
yield tuple(pool[i] for i in indices[:r])
|
||||||
|
while n:
|
||||||
|
for i in reversed(range(r)):
|
||||||
|
cycles[i] -= 1
|
||||||
|
if cycles[i] == 0:
|
||||||
|
indices[i:] = indices[i+1:] + indices[i:i+1]
|
||||||
|
cycles[i] = n - i
|
||||||
|
else:
|
||||||
|
j = cycles[i]
|
||||||
|
indices[i], indices[-j] = indices[-j], indices[i]
|
||||||
|
yield tuple(pool[i] for i in indices[:r])
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
|
def permutations2(iterable, r=None):
|
||||||
|
'Pure python version shown in the docs'
|
||||||
|
pool = tuple(iterable)
|
||||||
|
n = len(pool)
|
||||||
|
r = n if r is None else r
|
||||||
|
for indices in product(range(n), repeat=r):
|
||||||
|
if len(set(indices)) == r:
|
||||||
|
yield tuple(pool[i] for i in indices)
|
||||||
|
|
||||||
|
for n in range(7):
|
||||||
|
values = [5*x-12 for x in range(n)]
|
||||||
|
for r in range(n+1):
|
||||||
|
result = list(permutations(values, r))
|
||||||
|
self.assertEqual(len(result), fact(n) / fact(n-r)) # right number of perms
|
||||||
|
self.assertEqual(len(result), len(set(result))) # no repeats
|
||||||
|
self.assertEqual(result, sorted(result)) # lexicographic order
|
||||||
|
for p in result:
|
||||||
|
self.assertEqual(len(p), r) # r-length permutations
|
||||||
|
self.assertEqual(len(set(p)), r) # no duplicate elements
|
||||||
|
self.assert_(all(e in values for e in p)) # elements taken from input iterable
|
||||||
|
self.assertEqual(result, list(permutations1(values, r))) # matches first pure python version
|
||||||
|
self.assertEqual(result, list(permutations2(values, r))) # matches first pure python version
|
||||||
|
if r == n:
|
||||||
|
self.assertEqual(result, list(permutations(values, None))) # test r as None
|
||||||
|
self.assertEqual(result, list(permutations(values))) # test default r
|
||||||
|
|
||||||
|
# Test implementation detail: tuple re-use
|
||||||
|
## self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1)
|
||||||
|
self.assertNotEqual(len(set(map(id, list(permutations('abcde', 3))))), 1)
|
||||||
|
|
||||||
def test_count(self):
|
def test_count(self):
|
||||||
self.assertEqual(zip('abc',count()), [('a', 0), ('b', 1), ('c', 2)])
|
self.assertEqual(zip('abc',count()), [('a', 0), ('b', 1), ('c', 2)])
|
||||||
|
@ -288,7 +394,7 @@ class TestBasicOps(unittest.TestCase):
|
||||||
|
|
||||||
def test_product(self):
|
def test_product(self):
|
||||||
for args, result in [
|
for args, result in [
|
||||||
([], []), # zero iterables ??? is this correct
|
([], [()]), # zero iterables
|
||||||
(['ab'], [('a',), ('b',)]), # one iterable
|
(['ab'], [('a',), ('b',)]), # one iterable
|
||||||
([range(2), range(3)], [(0,0), (0,1), (0,2), (1,0), (1,1), (1,2)]), # two iterables
|
([range(2), range(3)], [(0,0), (0,1), (0,2), (1,0), (1,1), (1,2)]), # two iterables
|
||||||
([range(0), range(2), range(3)], []), # first iterable with zero length
|
([range(0), range(2), range(3)], []), # first iterable with zero length
|
||||||
|
@ -305,10 +411,10 @@ class TestBasicOps(unittest.TestCase):
|
||||||
set('abcdefg'), range(11), tuple(range(13))]
|
set('abcdefg'), range(11), tuple(range(13))]
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
args = [random.choice(argtypes) for j in range(random.randrange(5))]
|
args = [random.choice(argtypes) for j in range(random.randrange(5))]
|
||||||
n = reduce(operator.mul, map(len, args), 1) if args else 0
|
expected_len = prod(map(len, args))
|
||||||
self.assertEqual(len(list(product(*args))), n)
|
self.assertEqual(len(list(product(*args))), expected_len)
|
||||||
args = map(iter, args)
|
args = map(iter, args)
|
||||||
self.assertEqual(len(list(product(*args))), n)
|
self.assertEqual(len(list(product(*args))), expected_len)
|
||||||
|
|
||||||
# Test implementation detail: tuple re-use
|
# Test implementation detail: tuple re-use
|
||||||
self.assertEqual(len(set(map(id, product('abc', 'def')))), 1)
|
self.assertEqual(len(set(map(id, product('abc', 'def')))), 1)
|
||||||
|
|
|
@ -1885,10 +1885,7 @@ product_next(productobject *lz)
|
||||||
|
|
||||||
if (result == NULL) {
|
if (result == NULL) {
|
||||||
/* On the first pass, return an initial tuple filled with the
|
/* On the first pass, return an initial tuple filled with the
|
||||||
first element from each pool. If any pool is empty, then
|
first element from each pool. */
|
||||||
whole product is empty and we're already done */
|
|
||||||
if (npools == 0)
|
|
||||||
goto empty;
|
|
||||||
result = PyTuple_New(npools);
|
result = PyTuple_New(npools);
|
||||||
if (result == NULL)
|
if (result == NULL)
|
||||||
goto empty;
|
goto empty;
|
||||||
|
|
Loading…
Reference in New Issue