mirror of https://github.com/python/cpython
Merge in r68394 fixing itertools.permutations() and combinations().
This commit is contained in:
parent
5e4e4278c9
commit
5bad41eefc
|
@ -104,7 +104,9 @@ loops that truncate the stream.
|
||||||
# combinations(range(4), 3) --> 012 013 023 123
|
# combinations(range(4), 3) --> 012 013 023 123
|
||||||
pool = tuple(iterable)
|
pool = tuple(iterable)
|
||||||
n = len(pool)
|
n = len(pool)
|
||||||
indices = range(r)
|
if r > n:
|
||||||
|
return
|
||||||
|
indices = list(range(r))
|
||||||
yield tuple(pool[i] for i in indices)
|
yield tuple(pool[i] for i in indices)
|
||||||
while 1:
|
while 1:
|
||||||
for i in reversed(range(r)):
|
for i in reversed(range(r)):
|
||||||
|
@ -128,6 +130,8 @@ loops that truncate the stream.
|
||||||
if sorted(indices) == list(indices):
|
if sorted(indices) == list(indices):
|
||||||
yield tuple(pool[i] for i in indices)
|
yield tuple(pool[i] for i in indices)
|
||||||
|
|
||||||
|
The number of items returned is ``n! / r! / (n-r)!`` when ``0 <= r <= n``
|
||||||
|
or zero when ``r > n``.
|
||||||
|
|
||||||
.. function:: count([n])
|
.. function:: count([n])
|
||||||
|
|
||||||
|
@ -325,7 +329,9 @@ loops that truncate the stream.
|
||||||
pool = tuple(iterable)
|
pool = tuple(iterable)
|
||||||
n = len(pool)
|
n = len(pool)
|
||||||
r = n if r is None else r
|
r = n if r is None else r
|
||||||
indices = range(n)
|
if r > n:
|
||||||
|
return
|
||||||
|
indices = list(range(n))
|
||||||
cycles = range(n, n-r, -1)
|
cycles = range(n, n-r, -1)
|
||||||
yield tuple(pool[i] for i in indices[:r])
|
yield tuple(pool[i] for i in indices[:r])
|
||||||
while n:
|
while n:
|
||||||
|
@ -354,6 +360,8 @@ loops that truncate the stream.
|
||||||
if len(set(indices)) == r:
|
if len(set(indices)) == r:
|
||||||
yield tuple(pool[i] for i in indices)
|
yield tuple(pool[i] for i in indices)
|
||||||
|
|
||||||
|
The number of items returned is ``n! / (n-r)!`` when ``0 <= r <= n``
|
||||||
|
or zero when ``r > n``.
|
||||||
|
|
||||||
.. function:: product(*iterables[, repeat])
|
.. function:: product(*iterables[, repeat])
|
||||||
|
|
||||||
|
@ -593,7 +601,8 @@ which incur interpreter overhead.
|
||||||
return (d for d, s in zip(data, selectors) if s)
|
return (d for d, s in zip(data, selectors) if s)
|
||||||
|
|
||||||
def combinations_with_replacement(iterable, r):
|
def combinations_with_replacement(iterable, r):
|
||||||
"combinations_with_replacement('ABC', 3) --> AA AB AC BB BC CC"
|
"combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC"
|
||||||
|
# number items returned: (n+r-1)! / r! / (n-1)!
|
||||||
pool = tuple(iterable)
|
pool = tuple(iterable)
|
||||||
n = len(pool)
|
n = len(pool)
|
||||||
indices = [0] * r
|
indices = [0] * r
|
||||||
|
|
|
@ -75,11 +75,11 @@ class TestBasicOps(unittest.TestCase):
|
||||||
self.assertRaises(TypeError, list, chain.from_iterable([2, 3]))
|
self.assertRaises(TypeError, list, chain.from_iterable([2, 3]))
|
||||||
|
|
||||||
def test_combinations(self):
|
def test_combinations(self):
|
||||||
self.assertRaises(TypeError, combinations, 'abc') # missing r argument
|
self.assertRaises(TypeError, combinations, 'abc') # missing r argument
|
||||||
self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments
|
self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments
|
||||||
self.assertRaises(TypeError, combinations, None) # pool is not iterable
|
self.assertRaises(TypeError, combinations, None) # pool is not iterable
|
||||||
self.assertRaises(ValueError, combinations, 'abc', -2) # r is negative
|
self.assertRaises(ValueError, combinations, 'abc', -2) # r is negative
|
||||||
self.assertRaises(ValueError, combinations, 'abc', 32) # r is too big
|
self.assertEqual(list(combinations('abc', 32)), []) # r > n
|
||||||
self.assertEqual(list(combinations(range(4), 3)),
|
self.assertEqual(list(combinations(range(4), 3)),
|
||||||
[(0,1,2), (0,1,3), (0,2,3), (1,2,3)])
|
[(0,1,2), (0,1,3), (0,2,3), (1,2,3)])
|
||||||
|
|
||||||
|
@ -87,6 +87,8 @@ class TestBasicOps(unittest.TestCase):
|
||||||
'Pure python version shown in the docs'
|
'Pure python version shown in the docs'
|
||||||
pool = tuple(iterable)
|
pool = tuple(iterable)
|
||||||
n = len(pool)
|
n = len(pool)
|
||||||
|
if r > n:
|
||||||
|
return
|
||||||
indices = list(range(r))
|
indices = list(range(r))
|
||||||
yield tuple(pool[i] for i in indices)
|
yield tuple(pool[i] for i in indices)
|
||||||
while 1:
|
while 1:
|
||||||
|
@ -110,9 +112,9 @@ class TestBasicOps(unittest.TestCase):
|
||||||
|
|
||||||
for n in range(7):
|
for n in range(7):
|
||||||
values = [5*x-12 for x in range(n)]
|
values = [5*x-12 for x in range(n)]
|
||||||
for r in range(n+1):
|
for r in range(n+2):
|
||||||
result = list(combinations(values, r))
|
result = list(combinations(values, r))
|
||||||
self.assertEqual(len(result), fact(n) / fact(r) / fact(n-r)) # right number of combs
|
self.assertEqual(len(result), 0 if r>n else fact(n) / fact(r) / fact(n-r)) # right number of combs
|
||||||
self.assertEqual(len(result), len(set(result))) # no repeats
|
self.assertEqual(len(result), len(set(result))) # no repeats
|
||||||
self.assertEqual(result, sorted(result)) # lexicographic order
|
self.assertEqual(result, sorted(result)) # lexicographic order
|
||||||
for c in result:
|
for c in result:
|
||||||
|
@ -123,7 +125,7 @@ class TestBasicOps(unittest.TestCase):
|
||||||
self.assertEqual(list(c),
|
self.assertEqual(list(c),
|
||||||
[e for e in values if e in c]) # comb is a subsequence of the input iterable
|
[e for e in values if e in c]) # comb is a subsequence of the input iterable
|
||||||
self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version
|
self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version
|
||||||
self.assertEqual(result, list(combinations2(values, r))) # matches first pure python version
|
self.assertEqual(result, list(combinations2(values, r))) # matches second pure python version
|
||||||
|
|
||||||
# Test implementation detail: tuple re-use
|
# Test implementation detail: tuple re-use
|
||||||
self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1)
|
self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1)
|
||||||
|
@ -134,7 +136,7 @@ class TestBasicOps(unittest.TestCase):
|
||||||
self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments
|
self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments
|
||||||
self.assertRaises(TypeError, permutations, None) # pool is not iterable
|
self.assertRaises(TypeError, permutations, None) # pool is not iterable
|
||||||
self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative
|
self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative
|
||||||
self.assertRaises(ValueError, permutations, 'abc', 32) # r is too big
|
self.assertEqual(list(permutations('abc', 32)), []) # r > n
|
||||||
self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None
|
self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None
|
||||||
self.assertEqual(list(permutations(range(3), 2)),
|
self.assertEqual(list(permutations(range(3), 2)),
|
||||||
[(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)])
|
[(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)])
|
||||||
|
@ -144,6 +146,8 @@ class TestBasicOps(unittest.TestCase):
|
||||||
pool = tuple(iterable)
|
pool = tuple(iterable)
|
||||||
n = len(pool)
|
n = len(pool)
|
||||||
r = n if r is None else r
|
r = n if r is None else r
|
||||||
|
if r > n:
|
||||||
|
return
|
||||||
indices = list(range(n))
|
indices = list(range(n))
|
||||||
cycles = list(range(n-r+1, n+1))[::-1]
|
cycles = list(range(n-r+1, n+1))[::-1]
|
||||||
yield tuple(pool[i] for i in indices[:r])
|
yield tuple(pool[i] for i in indices[:r])
|
||||||
|
@ -172,9 +176,9 @@ class TestBasicOps(unittest.TestCase):
|
||||||
|
|
||||||
for n in range(7):
|
for n in range(7):
|
||||||
values = [5*x-12 for x in range(n)]
|
values = [5*x-12 for x in range(n)]
|
||||||
for r in range(n+1):
|
for r in range(n+2):
|
||||||
result = list(permutations(values, r))
|
result = list(permutations(values, r))
|
||||||
self.assertEqual(len(result), fact(n) / fact(n-r)) # right number of perms
|
self.assertEqual(len(result), 0 if r>n else fact(n) / fact(n-r)) # right number of perms
|
||||||
self.assertEqual(len(result), len(set(result))) # no repeats
|
self.assertEqual(len(result), len(set(result))) # no repeats
|
||||||
self.assertEqual(result, sorted(result)) # lexicographic order
|
self.assertEqual(result, sorted(result)) # lexicographic order
|
||||||
for p in result:
|
for p in result:
|
||||||
|
@ -182,7 +186,7 @@ class TestBasicOps(unittest.TestCase):
|
||||||
self.assertEqual(len(set(p)), r) # no duplicate elements
|
self.assertEqual(len(set(p)), r) # no duplicate elements
|
||||||
self.assert_(all(e in values for e in p)) # elements taken from input iterable
|
self.assert_(all(e in values for e in p)) # elements taken from input iterable
|
||||||
self.assertEqual(result, list(permutations1(values, r))) # matches first pure python version
|
self.assertEqual(result, list(permutations1(values, r))) # matches first pure python version
|
||||||
self.assertEqual(result, list(permutations2(values, r))) # matches first pure python version
|
self.assertEqual(result, list(permutations2(values, r))) # matches second pure python version
|
||||||
if r == n:
|
if r == n:
|
||||||
self.assertEqual(result, list(permutations(values, None))) # test r as None
|
self.assertEqual(result, list(permutations(values, None))) # test r as None
|
||||||
self.assertEqual(result, list(permutations(values))) # test default r
|
self.assertEqual(result, list(permutations(values))) # test default r
|
||||||
|
@ -1384,6 +1388,26 @@ perform as purported.
|
||||||
>>> list(combinations_with_replacement('abc', 2))
|
>>> list(combinations_with_replacement('abc', 2))
|
||||||
[('a', 'a'), ('a', 'b'), ('a', 'c'), ('b', 'b'), ('b', 'c'), ('c', 'c')]
|
[('a', 'a'), ('a', 'b'), ('a', 'c'), ('b', 'b'), ('b', 'c'), ('c', 'c')]
|
||||||
|
|
||||||
|
>>> list(combinations_with_replacement('01', 3))
|
||||||
|
[('0', '0', '0'), ('0', '0', '1'), ('0', '1', '1'), ('1', '1', '1')]
|
||||||
|
|
||||||
|
>>> def combinations_with_replacement2(iterable, r):
|
||||||
|
... 'Alternate version that filters from product()'
|
||||||
|
... pool = tuple(iterable)
|
||||||
|
... n = len(pool)
|
||||||
|
... for indices in product(range(n), repeat=r):
|
||||||
|
... if sorted(indices) == list(indices):
|
||||||
|
... yield tuple(pool[i] for i in indices)
|
||||||
|
|
||||||
|
>>> list(combinations_with_replacement('abc', 2)) == list(combinations_with_replacement2('abc', 2))
|
||||||
|
True
|
||||||
|
|
||||||
|
>>> list(combinations_with_replacement('01', 3)) == list(combinations_with_replacement2('01', 3))
|
||||||
|
True
|
||||||
|
|
||||||
|
>>> list(combinations_with_replacement('2310', 6)) == list(combinations_with_replacement2('2310', 6))
|
||||||
|
True
|
||||||
|
|
||||||
>>> list(unique_everseen('AAAABBBCCDAABBB'))
|
>>> list(unique_everseen('AAAABBBCCDAABBB'))
|
||||||
['A', 'B', 'C', 'D']
|
['A', 'B', 'C', 'D']
|
||||||
|
|
||||||
|
|
|
@ -1880,10 +1880,6 @@ combinations_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
|
||||||
PyErr_SetString(PyExc_ValueError, "r must be non-negative");
|
PyErr_SetString(PyExc_ValueError, "r must be non-negative");
|
||||||
goto error;
|
goto error;
|
||||||
}
|
}
|
||||||
if (r > n) {
|
|
||||||
PyErr_SetString(PyExc_ValueError, "r cannot be bigger than the iterable");
|
|
||||||
goto error;
|
|
||||||
}
|
|
||||||
|
|
||||||
indices = PyMem_Malloc(r * sizeof(Py_ssize_t));
|
indices = PyMem_Malloc(r * sizeof(Py_ssize_t));
|
||||||
if (indices == NULL) {
|
if (indices == NULL) {
|
||||||
|
@ -1903,7 +1899,7 @@ combinations_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
|
||||||
co->indices = indices;
|
co->indices = indices;
|
||||||
co->result = NULL;
|
co->result = NULL;
|
||||||
co->r = r;
|
co->r = r;
|
||||||
co->stopped = 0;
|
co->stopped = r > n ? 1 : 0;
|
||||||
|
|
||||||
return (PyObject *)co;
|
return (PyObject *)co;
|
||||||
|
|
||||||
|
@ -2143,10 +2139,6 @@ permutations_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
|
||||||
PyErr_SetString(PyExc_ValueError, "r must be non-negative");
|
PyErr_SetString(PyExc_ValueError, "r must be non-negative");
|
||||||
goto error;
|
goto error;
|
||||||
}
|
}
|
||||||
if (r > n) {
|
|
||||||
PyErr_SetString(PyExc_ValueError, "r cannot be bigger than the iterable");
|
|
||||||
goto error;
|
|
||||||
}
|
|
||||||
|
|
||||||
indices = PyMem_Malloc(n * sizeof(Py_ssize_t));
|
indices = PyMem_Malloc(n * sizeof(Py_ssize_t));
|
||||||
cycles = PyMem_Malloc(r * sizeof(Py_ssize_t));
|
cycles = PyMem_Malloc(r * sizeof(Py_ssize_t));
|
||||||
|
@ -2170,7 +2162,7 @@ permutations_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
|
||||||
po->cycles = cycles;
|
po->cycles = cycles;
|
||||||
po->result = NULL;
|
po->result = NULL;
|
||||||
po->r = r;
|
po->r = r;
|
||||||
po->stopped = 0;
|
po->stopped = r > n ? 1 : 0;
|
||||||
|
|
||||||
return (PyObject *)po;
|
return (PyObject *)po;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue