From 5bad41eefc9f80298bb1abe00a5475be8b015c57 Mon Sep 17 00:00:00 2001 From: Raymond Hettinger Date: Thu, 8 Jan 2009 21:01:54 +0000 Subject: [PATCH] Merge in r68394 fixing itertools.permutations() and combinations(). --- Doc/library/itertools.rst | 15 +++++++++++--- Lib/test/test_itertools.py | 42 ++++++++++++++++++++++++++++++-------- Modules/itertoolsmodule.c | 12 ++--------- 3 files changed, 47 insertions(+), 22 deletions(-) diff --git a/Doc/library/itertools.rst b/Doc/library/itertools.rst index b384cc8735b..db10b6d1ac5 100644 --- a/Doc/library/itertools.rst +++ b/Doc/library/itertools.rst @@ -104,7 +104,9 @@ loops that truncate the stream. # combinations(range(4), 3) --> 012 013 023 123 pool = tuple(iterable) n = len(pool) - indices = range(r) + if r > n: + return + indices = list(range(r)) yield tuple(pool[i] for i in indices) while 1: for i in reversed(range(r)): @@ -128,6 +130,8 @@ loops that truncate the stream. if sorted(indices) == list(indices): yield tuple(pool[i] for i in indices) + The number of items returned is ``n! / r! / (n-r)!`` when ``0 <= r <= n`` + or zero when ``r > n``. .. function:: count([n]) @@ -325,7 +329,9 @@ loops that truncate the stream. pool = tuple(iterable) n = len(pool) r = n if r is None else r - indices = range(n) + if r > n: + return + indices = list(range(n)) cycles = range(n, n-r, -1) yield tuple(pool[i] for i in indices[:r]) while n: @@ -354,6 +360,8 @@ loops that truncate the stream. if len(set(indices)) == r: yield tuple(pool[i] for i in indices) + The number of items returned is ``n! / (n-r)!`` when ``0 <= r <= n`` + or zero when ``r > n``. .. function:: product(*iterables[, repeat]) @@ -593,7 +601,8 @@ which incur interpreter overhead. return (d for d, s in zip(data, selectors) if s) def combinations_with_replacement(iterable, r): - "combinations_with_replacement('ABC', 3) --> AA AB AC BB BC CC" + "combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC" + # number items returned: (n+r-1)! / r! / (n-1)! pool = tuple(iterable) n = len(pool) indices = [0] * r diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 80424f813a9..ba55d23e95c 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -75,11 +75,11 @@ class TestBasicOps(unittest.TestCase): self.assertRaises(TypeError, list, chain.from_iterable([2, 3])) def test_combinations(self): - self.assertRaises(TypeError, combinations, 'abc') # missing r argument + self.assertRaises(TypeError, combinations, 'abc') # missing r argument self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments self.assertRaises(TypeError, combinations, None) # pool is not iterable self.assertRaises(ValueError, combinations, 'abc', -2) # r is negative - self.assertRaises(ValueError, combinations, 'abc', 32) # r is too big + self.assertEqual(list(combinations('abc', 32)), []) # r > n self.assertEqual(list(combinations(range(4), 3)), [(0,1,2), (0,1,3), (0,2,3), (1,2,3)]) @@ -87,6 +87,8 @@ class TestBasicOps(unittest.TestCase): 'Pure python version shown in the docs' pool = tuple(iterable) n = len(pool) + if r > n: + return indices = list(range(r)) yield tuple(pool[i] for i in indices) while 1: @@ -110,9 +112,9 @@ class TestBasicOps(unittest.TestCase): for n in range(7): values = [5*x-12 for x in range(n)] - for r in range(n+1): + for r in range(n+2): result = list(combinations(values, r)) - self.assertEqual(len(result), fact(n) / fact(r) / fact(n-r)) # right number of combs + self.assertEqual(len(result), 0 if r>n else fact(n) / fact(r) / fact(n-r)) # right number of combs self.assertEqual(len(result), len(set(result))) # no repeats self.assertEqual(result, sorted(result)) # lexicographic order for c in result: @@ -123,7 +125,7 @@ class TestBasicOps(unittest.TestCase): self.assertEqual(list(c), [e for e in values if e in c]) # comb is a subsequence of the input iterable self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version - self.assertEqual(result, list(combinations2(values, r))) # matches first pure python version + self.assertEqual(result, list(combinations2(values, r))) # matches second pure python version # Test implementation detail: tuple re-use self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1) @@ -134,7 +136,7 @@ class TestBasicOps(unittest.TestCase): self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments self.assertRaises(TypeError, permutations, None) # pool is not iterable self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative - self.assertRaises(ValueError, permutations, 'abc', 32) # r is too big + self.assertEqual(list(permutations('abc', 32)), []) # r > n self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None self.assertEqual(list(permutations(range(3), 2)), [(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)]) @@ -144,6 +146,8 @@ class TestBasicOps(unittest.TestCase): pool = tuple(iterable) n = len(pool) r = n if r is None else r + if r > n: + return indices = list(range(n)) cycles = list(range(n-r+1, n+1))[::-1] yield tuple(pool[i] for i in indices[:r]) @@ -172,9 +176,9 @@ class TestBasicOps(unittest.TestCase): for n in range(7): values = [5*x-12 for x in range(n)] - for r in range(n+1): + for r in range(n+2): result = list(permutations(values, r)) - self.assertEqual(len(result), fact(n) / fact(n-r)) # right number of perms + self.assertEqual(len(result), 0 if r>n else fact(n) / fact(n-r)) # right number of perms self.assertEqual(len(result), len(set(result))) # no repeats self.assertEqual(result, sorted(result)) # lexicographic order for p in result: @@ -182,7 +186,7 @@ class TestBasicOps(unittest.TestCase): self.assertEqual(len(set(p)), r) # no duplicate elements self.assert_(all(e in values for e in p)) # elements taken from input iterable self.assertEqual(result, list(permutations1(values, r))) # matches first pure python version - self.assertEqual(result, list(permutations2(values, r))) # matches first pure python version + self.assertEqual(result, list(permutations2(values, r))) # matches second pure python version if r == n: self.assertEqual(result, list(permutations(values, None))) # test r as None self.assertEqual(result, list(permutations(values))) # test default r @@ -1384,6 +1388,26 @@ perform as purported. >>> list(combinations_with_replacement('abc', 2)) [('a', 'a'), ('a', 'b'), ('a', 'c'), ('b', 'b'), ('b', 'c'), ('c', 'c')] +>>> list(combinations_with_replacement('01', 3)) +[('0', '0', '0'), ('0', '0', '1'), ('0', '1', '1'), ('1', '1', '1')] + +>>> def combinations_with_replacement2(iterable, r): +... 'Alternate version that filters from product()' +... pool = tuple(iterable) +... n = len(pool) +... for indices in product(range(n), repeat=r): +... if sorted(indices) == list(indices): +... yield tuple(pool[i] for i in indices) + +>>> list(combinations_with_replacement('abc', 2)) == list(combinations_with_replacement2('abc', 2)) +True + +>>> list(combinations_with_replacement('01', 3)) == list(combinations_with_replacement2('01', 3)) +True + +>>> list(combinations_with_replacement('2310', 6)) == list(combinations_with_replacement2('2310', 6)) +True + >>> list(unique_everseen('AAAABBBCCDAABBB')) ['A', 'B', 'C', 'D'] diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index 3a095b6ea6c..8125dcb0642 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -1880,10 +1880,6 @@ combinations_new(PyTypeObject *type, PyObject *args, PyObject *kwds) PyErr_SetString(PyExc_ValueError, "r must be non-negative"); goto error; } - if (r > n) { - PyErr_SetString(PyExc_ValueError, "r cannot be bigger than the iterable"); - goto error; - } indices = PyMem_Malloc(r * sizeof(Py_ssize_t)); if (indices == NULL) { @@ -1903,7 +1899,7 @@ combinations_new(PyTypeObject *type, PyObject *args, PyObject *kwds) co->indices = indices; co->result = NULL; co->r = r; - co->stopped = 0; + co->stopped = r > n ? 1 : 0; return (PyObject *)co; @@ -2143,10 +2139,6 @@ permutations_new(PyTypeObject *type, PyObject *args, PyObject *kwds) PyErr_SetString(PyExc_ValueError, "r must be non-negative"); goto error; } - if (r > n) { - PyErr_SetString(PyExc_ValueError, "r cannot be bigger than the iterable"); - goto error; - } indices = PyMem_Malloc(n * sizeof(Py_ssize_t)); cycles = PyMem_Malloc(r * sizeof(Py_ssize_t)); @@ -2170,7 +2162,7 @@ permutations_new(PyTypeObject *type, PyObject *args, PyObject *kwds) po->cycles = cycles; po->result = NULL; po->r = r; - po->stopped = 0; + po->stopped = r > n ? 1 : 0; return (PyObject *)po;