From d081abc8be6192980c1fb06c837cc9e4fc373f55 Mon Sep 17 00:00:00 2001 From: Raymond Hettinger Date: Tue, 27 Jan 2009 02:58:49 +0000 Subject: [PATCH] Promote combinations_with_replacement() from a recipe to a regular itertool. --- Doc/library/collections.rst | 3 +- Doc/library/itertools.rst | 63 ++++++--- Lib/test/test_itertools.py | 116 +++++++++++------ Misc/NEWS | 3 +- Modules/itertoolsmodule.c | 253 +++++++++++++++++++++++++++++++++++- 5 files changed, 379 insertions(+), 59 deletions(-) diff --git a/Doc/library/collections.rst b/Doc/library/collections.rst index b358e38dbff..5b25b475b72 100644 --- a/Doc/library/collections.rst +++ b/Doc/library/collections.rst @@ -291,8 +291,7 @@ counts less than one:: Section 4.6.3, Exercise 19*\. * To enumerate all distinct multisets of a given size over a given set of - elements, see :func:`combinations_with_replacement` in the - :ref:`itertools-recipes` for itertools:: + elements, see :func:`itertools.combinations_with_replacement`. map(Counter, combinations_with_replacement('ABC', 2)) --> AA AB AC BB BC CC diff --git a/Doc/library/itertools.rst b/Doc/library/itertools.rst index b7cd4318bf5..9aff478c65f 100644 --- a/Doc/library/itertools.rst +++ b/Doc/library/itertools.rst @@ -139,6 +139,53 @@ loops that truncate the stream. .. versionadded:: 2.6 +.. function:: combinations_with_replacement(iterable, r) + + Return *r* length subsequences of elements from the input *iterable* + allowing individual elements to be repeated more than once. + + Combinations are emitted in lexicographic sort order. So, if the + input *iterable* is sorted, the combination tuples will be produced + in sorted order. + + Elements are treated as unique based on their position, not on their + value. So if the input elements are unique, the generated combinations + will also be unique. + + Equivalent to:: + + def combinations_with_replacement(iterable, r): + # combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC + pool = tuple(iterable) + n = len(pool) + if not n and r: + return + indices = [0] * r + yield tuple(pool[i] for i in indices) + while 1: + for i in reversed(range(r)): + if indices[i] != n - 1: + break + else: + return + indices[i:] = [indices[i] + 1] * (r - i) + yield tuple(pool[i] for i in indices) + + The code for :func:`combinations_with_replacement` can be also expressed as + a subsequence of :func:`product` after filtering entries where the elements + are not in sorted order (according to their position in the input pool):: + + def combinations_with_replacement(iterable, r): + pool = tuple(iterable) + n = len(pool) + for indices in product(range(n), repeat=r): + if sorted(indices) == list(indices): + yield tuple(pool[i] for i in indices) + + The number of items returned is ``(n+r-1)! / r! / (n-1)!`` when ``n > 0``. + + .. versionadded:: 2.7 + .. function:: compress(data, selectors) Make an iterator that filters elements from *data* returning only those that @@ -691,22 +738,6 @@ which incur interpreter overhead. s = list(iterable) return chain.from_iterable(combinations(s, r) for r in range(len(s)+1)) - def combinations_with_replacement(iterable, r): - "combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC" - # number items returned: (n+r-1)! / r! / (n-1)! - pool = tuple(iterable) - n = len(pool) - indices = [0] * r - yield tuple(pool[i] for i in indices) - while 1: - for i in reversed(range(r)): - if indices[i] != n - 1: - break - else: - return - indices[i:] = [indices[i] + 1] * (r - i) - yield tuple(pool[i] for i in indices) - def unique_everseen(iterable, key=None): "List unique elements, preserving order. Remember all elements ever seen." # unique_everseen('AAAABBBCCDAABBB') --> A B C D diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 9b399c0f665..23a87654c7a 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -127,6 +127,76 @@ class TestBasicOps(unittest.TestCase): self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1) self.assertNotEqual(len(set(map(id, list(combinations('abcde', 3))))), 1) + def test_combinations_with_replacement(self): + cwr = combinations_with_replacement + self.assertRaises(TypeError, cwr, 'abc') # missing r argument + self.assertRaises(TypeError, cwr, 'abc', 2, 1) # too many arguments + self.assertRaises(TypeError, cwr, None) # pool is not iterable + self.assertRaises(ValueError, cwr, 'abc', -2) # r is negative + self.assertEqual(list(cwr('ABC', 2)), + [('A','A'), ('A','B'), ('A','C'), ('B','B'), ('B','C'), ('C','C')]) + + def cwr1(iterable, r): + 'Pure python version shown in the docs' + # number items returned: (n+r-1)! / r! / (n-1)! when n>0 + pool = tuple(iterable) + n = len(pool) + if not n and r: + return + indices = [0] * r + yield tuple(pool[i] for i in indices) + while 1: + for i in reversed(range(r)): + if indices[i] != n - 1: + break + else: + return + indices[i:] = [indices[i] + 1] * (r - i) + yield tuple(pool[i] for i in indices) + + def cwr2(iterable, r): + 'Pure python version shown in the docs' + pool = tuple(iterable) + n = len(pool) + for indices in product(range(n), repeat=r): + if sorted(indices) == list(indices): + yield tuple(pool[i] for i in indices) + + def numcombs(n, r): + if not n: + return 0 if r else 1 + return fact(n+r-1) / fact(r)/ fact(n-1) + + for n in range(7): + values = [5*x-12 for x in range(n)] + for r in range(n+2): + result = list(cwr(values, r)) + + self.assertEqual(len(result), numcombs(n, r)) # right number of combs + self.assertEqual(len(result), len(set(result))) # no repeats + self.assertEqual(result, sorted(result)) # lexicographic order + + regular_combs = list(combinations(values, r)) # compare to combs without replacement + if n == 0 or r <= 1: + self.assertEquals(result, regular_combs) # cases that should be identical + else: + self.assert_(set(result) >= set(regular_combs)) # rest should be supersets of regular combs + + for c in result: + self.assertEqual(len(c), r) # r-length combinations + noruns = [k for k,v in groupby(c)] # combo without consecutive repeats + self.assertEqual(len(noruns), len(set(noruns))) # no repeats other than consecutive + self.assertEqual(list(c), sorted(c)) # keep original ordering + self.assert_(all(e in values for e in c)) # elements taken from input iterable + self.assertEqual(noruns, + [e for e in values if e in c]) # comb is a subsequence of the input iterable + self.assertEqual(result, list(cwr1(values, r))) # matches first pure python version + self.assertEqual(result, list(cwr2(values, r))) # matches second pure python version + + # Test implementation detail: tuple re-use + self.assertEqual(len(set(map(id, cwr('abcde', 3)))), 1) + self.assertNotEqual(len(set(map(id, list(cwr('abcde', 3))))), 1) + def test_permutations(self): self.assertRaises(TypeError, permutations) # too few arguments self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments @@ -716,6 +786,10 @@ class TestExamples(unittest.TestCase): self.assertEqual(list(combinations(range(4), 3)), [(0,1,2), (0,1,3), (0,2,3), (1,2,3)]) + def test_combinations_with_replacement(self): + self.assertEqual(list(combinations_with_replacement('ABC', 2)), + [('A','A'), ('A','B'), ('A','C'), ('B','B'), ('B','C'), ('C','C')]) + def test_compress(self): self.assertEqual(list(compress('ABCDEF', [1,0,1,0,1,1])), list('ACEF')) @@ -799,6 +873,10 @@ class TestGC(unittest.TestCase): a = [] self.makecycle(combinations([1,2,a,3], 3), a) + def test_combinations_with_replacement(self): + a = [] + self.makecycle(combinations_with_replacement([1,2,a,3], 3), a) + def test_compress(self): a = [] self.makecycle(compress('ABCDEF', [1,0,1,0,1,0]), a) @@ -1291,21 +1369,6 @@ Samuele ... s = list(iterable) ... return chain.from_iterable(combinations(s, r) for r in range(len(s)+1)) ->>> def combinations_with_replacement(iterable, r): -... "combinations_with_replacement('ABC', 3) --> AA AB AC BB BC CC" -... pool = tuple(iterable) -... n = len(pool) -... indices = [0] * r -... yield tuple(pool[i] for i in indices) -... while 1: -... for i in reversed(range(r)): -... if indices[i] != n - 1: -... break -... else: -... return -... indices[i:] = [indices[i] + 1] * (r - i) -... yield tuple(pool[i] for i in indices) - >>> def unique_everseen(iterable, key=None): ... "List unique elements, preserving order. Remember all elements ever seen." ... # unique_everseen('AAAABBBCCDAABBB') --> A B C D @@ -1386,29 +1449,6 @@ perform as purported. >>> list(powerset([1,2,3])) [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)] ->>> list(combinations_with_replacement('abc', 2)) -[('a', 'a'), ('a', 'b'), ('a', 'c'), ('b', 'b'), ('b', 'c'), ('c', 'c')] - ->>> list(combinations_with_replacement('01', 3)) -[('0', '0', '0'), ('0', '0', '1'), ('0', '1', '1'), ('1', '1', '1')] - ->>> def combinations_with_replacement2(iterable, r): -... 'Alternate version that filters from product()' -... pool = tuple(iterable) -... n = len(pool) -... for indices in product(range(n), repeat=r): -... if sorted(indices) == list(indices): -... yield tuple(pool[i] for i in indices) - ->>> list(combinations_with_replacement('abc', 2)) == list(combinations_with_replacement2('abc', 2)) -True - ->>> list(combinations_with_replacement('01', 3)) == list(combinations_with_replacement2('01', 3)) -True - ->>> list(combinations_with_replacement('2310', 6)) == list(combinations_with_replacement2('2310', 6)) -True - >>> list(unique_everseen('AAAABBBCCDAABBB')) ['A', 'B', 'C', 'D'] diff --git a/Misc/NEWS b/Misc/NEWS index da016c1d18d..bb1aeac6b4d 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -153,7 +153,8 @@ Library - Issue #4863: distutils.mwerkscompiler has been removed. -- Added a new function: itertools.compress(). +- Added a new itertools functions: combinations_with_replacement() + and compress(). - Fix and properly document the multiprocessing module's logging support, expose the internal levels and provide proper usage diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index f66d052e907..221dbe5d15b 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -1862,7 +1862,8 @@ product_dealloc(productobject *lz) PyObject_GC_UnTrack(lz); Py_XDECREF(lz->pools); Py_XDECREF(lz->result); - PyMem_Free(lz->indices); + if (lz->indices != NULL) + PyMem_Free(lz->indices); Py_TYPE(lz)->tp_free(lz); } @@ -2090,7 +2091,8 @@ combinations_dealloc(combinationsobject *co) PyObject_GC_UnTrack(co); Py_XDECREF(co->pool); Py_XDECREF(co->result); - PyMem_Free(co->indices); + if (co->indices != NULL) + PyMem_Free(co->indices); Py_TYPE(co)->tp_free(co); } @@ -2239,6 +2241,252 @@ static PyTypeObject combinations_type = { }; +/* combinations with replacement object *******************************************/ + +/* Equivalent to: + + def combinations_with_replacement(iterable, r): + "combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC" + # number items returned: (n+r-1)! / r! / (n-1)! + pool = tuple(iterable) + n = len(pool) + indices = [0] * r + yield tuple(pool[i] for i in indices) + while 1: + for i in reversed(range(r)): + if indices[i] != n - 1: + break + else: + return + indices[i:] = [indices[i] + 1] * (r - i) + yield tuple(pool[i] for i in indices) + + def combinations_with_replacement2(iterable, r): + 'Alternate version that filters from product()' + pool = tuple(iterable) + n = len(pool) + for indices in product(range(n), repeat=r): + if sorted(indices) == list(indices): + yield tuple(pool[i] for i in indices) +*/ +typedef struct { + PyObject_HEAD + PyObject *pool; /* input converted to a tuple */ + Py_ssize_t *indices; /* one index per result element */ + PyObject *result; /* most recently returned result tuple */ + Py_ssize_t r; /* size of result tuple */ + int stopped; /* set to 1 when the cwr iterator is exhausted */ +} cwrobject; + +static PyTypeObject cwr_type; + +static PyObject * +cwr_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + cwrobject *co; + Py_ssize_t n; + Py_ssize_t r; + PyObject *pool = NULL; + PyObject *iterable = NULL; + Py_ssize_t *indices = NULL; + Py_ssize_t i; + static char *kwargs[] = {"iterable", "r", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "On:combinations_with_replacement", kwargs, + &iterable, &r)) + return NULL; + + pool = PySequence_Tuple(iterable); + if (pool == NULL) + goto error; + n = PyTuple_GET_SIZE(pool); + if (r < 0) { + PyErr_SetString(PyExc_ValueError, "r must be non-negative"); + goto error; + } + + indices = PyMem_Malloc(r * sizeof(Py_ssize_t)); + if (indices == NULL) { + PyErr_NoMemory(); + goto error; + } + + for (i=0 ; itp_alloc(type, 0); + if (co == NULL) + goto error; + + co->pool = pool; + co->indices = indices; + co->result = NULL; + co->r = r; + co->stopped = !n && r; + + return (PyObject *)co; + +error: + if (indices != NULL) + PyMem_Free(indices); + Py_XDECREF(pool); + return NULL; +} + +static void +cwr_dealloc(cwrobject *co) +{ + PyObject_GC_UnTrack(co); + Py_XDECREF(co->pool); + Py_XDECREF(co->result); + if (co->indices != NULL) + PyMem_Free(co->indices); + Py_TYPE(co)->tp_free(co); +} + +static int +cwr_traverse(cwrobject *co, visitproc visit, void *arg) +{ + Py_VISIT(co->pool); + Py_VISIT(co->result); + return 0; +} + +static PyObject * +cwr_next(cwrobject *co) +{ + PyObject *elem; + PyObject *oldelem; + PyObject *pool = co->pool; + Py_ssize_t *indices = co->indices; + PyObject *result = co->result; + Py_ssize_t n = PyTuple_GET_SIZE(pool); + Py_ssize_t r = co->r; + Py_ssize_t i, j, index; + + if (co->stopped) + return NULL; + + if (result == NULL) { + /* On the first pass, initialize result tuple using the indices */ + result = PyTuple_New(r); + if (result == NULL) + goto empty; + co->result = result; + for (i=0; i 1) { + PyObject *old_result = result; + result = PyTuple_New(r); + if (result == NULL) + goto empty; + co->result = result; + for (i=0; i= 0 && indices[i] == n-1; i--) + ; + + /* If i is negative, then the indices are all at + their maximum value and we're done. */ + if (i < 0) + goto empty; + + /* Increment the current index which we know is not at its + maximum. Then set all to the right to the same value. */ + indices[i]++; + for (j=i+1 ; jstopped = 1; + return NULL; +} + +PyDoc_STRVAR(cwr_doc, +"combinations_with_replacement(iterable[, r]) --> combinations_with_replacement object\n\ +\n\ +Return successive r-length combinations of elements in the iterable\n\ +allowing individual elements to have successive repeats.\n\ +combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC"); + +static PyTypeObject cwr_type = { + PyVarObject_HEAD_INIT(NULL, 0) + "itertools.combinations_with_replacement", /* tp_name */ + sizeof(cwrobject), /* tp_basicsize */ + 0, /* tp_itemsize */ + /* methods */ + (destructor)cwr_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + PyObject_GenericGetAttr, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_BASETYPE, /* tp_flags */ + cwr_doc, /* tp_doc */ + (traverseproc)cwr_traverse, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + PyObject_SelfIter, /* tp_iter */ + (iternextfunc)cwr_next, /* tp_iternext */ + 0, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + 0, /* tp_init */ + 0, /* tp_alloc */ + cwr_new, /* tp_new */ + PyObject_GC_Del, /* tp_free */ +}; + + /* permutations object ************************************************************ def permutations(iterable, r=None): @@ -3701,6 +3949,7 @@ inititertools(void) char *name; PyTypeObject *typelist[] = { &combinations_type, + &cwr_type, &cycle_type, &dropwhile_type, &takewhile_type,