Forward port r69001: itertools.combinations_with_replacement().

This commit is contained in:
Raymond Hettinger 2009-01-27 04:20:44 +00:00
parent dd1b33a2ed
commit d07d939c5e
5 changed files with 379 additions and 59 deletions

View File

@ -279,8 +279,7 @@ counts less than one::
Section 4.6.3, Exercise 19*\.
* To enumerate all distinct multisets of a given size over a given set of
elements, see :func:`combinations_with_replacement` in the
:ref:`itertools-recipes` for itertools::
elements, see :func:`itertools.combinations_with_replacement`.
map(Counter, combinations_with_replacement('ABC', 2)) --> AA AB AC BB BC CC

View File

@ -133,6 +133,53 @@ loops that truncate the stream.
The number of items returned is ``n! / r! / (n-r)!`` when ``0 <= r <= n``
or zero when ``r > n``.
.. function:: combinations_with_replacement(iterable, r)
Return *r* length subsequences of elements from the input *iterable*
allowing individual elements to be repeated more than once.
Combinations are emitted in lexicographic sort order. So, if the
input *iterable* is sorted, the combination tuples will be produced
in sorted order.
Elements are treated as unique based on their position, not on their
value. So if the input elements are unique, the generated combinations
will also be unique.
Equivalent to::
def combinations_with_replacement(iterable, r):
    # combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC
    items = tuple(iterable)
    last = len(items) - 1
    if last < 0 and r:
        # Empty pool but a non-empty result was requested: nothing to emit.
        return
    positions = [0] * r
    while True:
        yield tuple(items[p] for p in positions)
        # Walk right-to-left to the first position that is not yet at the
        # final pool index.
        pivot = r - 1
        while pivot >= 0 and positions[pivot] == last:
            pivot -= 1
        if pivot < 0:
            # Every position is at its maximum: the sequence is exhausted.
            return
        # Advance the pivot and reset everything to its right to the same
        # value, which keeps the positions non-decreasing.
        positions[pivot:] = [positions[pivot] + 1] * (r - pivot)
The code for :func:`combinations_with_replacement` can also be expressed as
a subsequence of :func:`product` after filtering entries where the elements
are not in sorted order (according to their position in the input pool)::
def combinations_with_replacement(iterable, r):
    # Keep only the product() index tuples that never decrease, i.e. the
    # ones already in sorted order by pool position.
    items = tuple(iterable)
    for picks in product(range(len(items)), repeat=r):
        if all(left <= right for left, right in zip(picks, picks[1:])):
            yield tuple(items[k] for k in picks)
The number of items returned is ``(n+r-1)! / r! / (n-1)!`` when ``n > 0``.
.. versionadded:: 3.1
.. function:: compress(data, selectors)
Make an iterator that filters elements from *data* returning only those that
@ -608,22 +655,6 @@ which incur interpreter overhead.
s = list(iterable)
return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
def combinations_with_replacement(iterable, r):
    "combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC"
    # number items returned:  (n+r-1)! / r! / (n-1)!  when n > 0
    pool = tuple(iterable)
    n = len(pool)
    if not n and r:
        # An empty pool cannot fill any of the r slots; without this guard
        # the first yield below would index into an empty tuple.
        return
    indices = [0] * r
    yield tuple(pool[i] for i in indices)
    while True:
        # Find the rightmost index that is not yet at its maximum (n - 1).
        for i in reversed(range(r)):
            if indices[i] != n - 1:
                break
        else:
            # All indices are at their maximum: no combinations remain.
            return
        # Bump that index and copy the new value to everything on its right.
        indices[i:] = [indices[i] + 1] * (r - i)
        yield tuple(pool[i] for i in indices)
def unique_everseen(iterable, key=None):
"List unique elements, preserving order. Remember all elements ever seen."
# unique_everseen('AAAABBBCCDAABBB') --> A B C D

View File

@ -131,6 +131,76 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1)
self.assertNotEqual(len(set(map(id, list(combinations('abcde', 3))))), 1)
def test_combinations_with_replacement(self):
    # Covers argument validation, the documented example, agreement with two
    # pure python models, and the result-tuple re-use implementation detail.
    cwr = combinations_with_replacement
    self.assertRaises(TypeError, cwr, 'abc')       # missing r argument
    self.assertRaises(TypeError, cwr, 'abc', 2, 1) # too many arguments
    self.assertRaises(TypeError, cwr, None)        # pool is not iterable
    self.assertRaises(ValueError, cwr, 'abc', -2)  # r is negative
    self.assertEqual(list(cwr('ABC', 2)),
                     [('A','A'), ('A','B'), ('A','C'), ('B','B'), ('B','C'), ('C','C')])

    def cwr1(iterable, r):
        'Pure python version shown in the docs'
        # number items returned:  (n+r-1)! / r! / (n-1)! when n>0
        pool = tuple(iterable)
        n = len(pool)
        if not n and r:
            return
        indices = [0] * r
        yield tuple(pool[i] for i in indices)
        while 1:
            for i in reversed(range(r)):
                if indices[i] != n - 1:
                    break
            else:
                return
            indices[i:] = [indices[i] + 1] * (r - i)
            yield tuple(pool[i] for i in indices)

    def cwr2(iterable, r):
        'Pure python version shown in the docs'
        pool = tuple(iterable)
        n = len(pool)
        for indices in product(range(n), repeat=r):
            if sorted(indices) == list(indices):
                yield tuple(pool[i] for i in indices)

    def numcombs(n, r):
        # Expected number of results.  Floor division keeps the count an
        # exact int; true division would produce a float under Python 3.
        if not n:
            return 0 if r else 1
        return fact(n+r-1) // fact(r) // fact(n-1)

    for n in range(7):
        values = [5*x-12 for x in range(n)]
        for r in range(n+2):
            result = list(cwr(values, r))

            self.assertEqual(len(result), numcombs(n, r))           # right number of combs
            self.assertEqual(len(result), len(set(result)))         # no repeats
            self.assertEqual(result, sorted(result))                # lexicographic order

            regular_combs = list(combinations(values, r))           # compare to combs without replacement
            if n == 0 or r <= 1:
                self.assertEqual(result, regular_combs)             # cases that should be identical
            else:
                self.assertTrue(set(result) >= set(regular_combs))  # rest should be supersets of regular combs

            for c in result:
                self.assertEqual(len(c), r)                         # r-length combinations
                noruns = [k for k,v in groupby(c)]                  # combo without consecutive repeats
                self.assertEqual(len(noruns), len(set(noruns)))     # no repeats other than consecutive
                self.assertEqual(list(c), sorted(c))                # keep original ordering
                self.assertTrue(all(e in values for e in c))        # elements taken from input iterable
                self.assertEqual(noruns,
                                 [e for e in values if e in c])     # comb is a subsequence of the input iterable
            self.assertEqual(result, list(cwr1(values, r)))         # matches first pure python version
            self.assertEqual(result, list(cwr2(values, r)))         # matches second pure python version

    # Test implementation detail:  tuple re-use
    self.assertEqual(len(set(map(id, cwr('abcde', 3)))), 1)
    self.assertNotEqual(len(set(map(id, list(cwr('abcde', 3))))), 1)
def test_permutations(self):
self.assertRaises(TypeError, permutations) # too few arguments
self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments
@ -730,6 +800,10 @@ class TestExamples(unittest.TestCase):
self.assertEqual(list(combinations(range(4), 3)),
[(0,1,2), (0,1,3), (0,2,3), (1,2,3)])
def test_combinations_with_replacement(self):
self.assertEqual(list(combinations_with_replacement('ABC', 2)),
[('A','A'), ('A','B'), ('A','C'), ('B','B'), ('B','C'), ('C','C')])
def test_compress(self):
self.assertEqual(list(compress('ABCDEF', [1,0,1,0,1,1])), list('ACEF'))
@ -813,6 +887,10 @@ class TestGC(unittest.TestCase):
a = []
self.makecycle(combinations([1,2,a,3], 3), a)
def test_combinations_with_replacement(self):
    # Hand makecycle() an iterator whose pool contains the container itself.
    container = []
    iterator = combinations_with_replacement([1, 2, container, 3], 3)
    self.makecycle(iterator, container)
def test_compress(self):
a = []
self.makecycle(compress('ABCDEF', [1,0,1,0,1,0]), a)
@ -1312,21 +1390,6 @@ Samuele
... s = list(iterable)
... return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
>>> def combinations_with_replacement(iterable, r):
... "combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC"
... # NOTE(review): no guard for an empty pool with r > 0; in that case the
... # first yield below indexes into an empty tuple.
... pool = tuple(iterable)
... n = len(pool)
... indices = [0] * r
... yield tuple(pool[i] for i in indices)
... while 1:
... # Scan right-to-left for an index that is not yet at n - 1.
... for i in reversed(range(r)):
... if indices[i] != n - 1:
... break
... else:
... return
... # Advance that index and copy its new value to every slot on its right.
... indices[i:] = [indices[i] + 1] * (r - i)
... yield tuple(pool[i] for i in indices)
>>> def unique_everseen(iterable, key=None):
... "List unique elements, preserving order. Remember all elements ever seen."
... # unique_everseen('AAAABBBCCDAABBB') --> A B C D
@ -1407,29 +1470,6 @@ perform as purported.
>>> list(powerset([1,2,3]))
[(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
>>> list(combinations_with_replacement('abc', 2))
[('a', 'a'), ('a', 'b'), ('a', 'c'), ('b', 'b'), ('b', 'c'), ('c', 'c')]
>>> list(combinations_with_replacement('01', 3))
[('0', '0', '0'), ('0', '0', '1'), ('0', '1', '1'), ('1', '1', '1')]
>>> def combinations_with_replacement2(iterable, r):
... 'Alternate version that filters from product()'
... pool = tuple(iterable)
... n = len(pool)
... for indices in product(range(n), repeat=r):
... if sorted(indices) == list(indices):
... yield tuple(pool[i] for i in indices)
>>> list(combinations_with_replacement('abc', 2)) == list(combinations_with_replacement2('abc', 2))
True
>>> list(combinations_with_replacement('01', 3)) == list(combinations_with_replacement2('01', 3))
True
>>> list(combinations_with_replacement('2310', 6)) == list(combinations_with_replacement2('2310', 6))
True
>>> list(unique_everseen('AAAABBBCCDAABBB'))
['A', 'B', 'C', 'D']

View File

@ -153,7 +153,8 @@ Library
- Issue #4863: distutils.mwerkscompiler has been removed.
- Added a new function: itertools.compress().
- Added new itertools functions: combinations_with_replacement()
and compress().
- Fix and properly document the multiprocessing module's logging
support, expose the internal levels and provide proper usage

View File

@ -1683,7 +1683,8 @@ product_dealloc(productobject *lz)
PyObject_GC_UnTrack(lz);
Py_XDECREF(lz->pools);
Py_XDECREF(lz->result);
PyMem_Free(lz->indices);
if (lz->indices != NULL)
PyMem_Free(lz->indices);
Py_TYPE(lz)->tp_free(lz);
}
@ -1911,7 +1912,8 @@ combinations_dealloc(combinationsobject *co)
PyObject_GC_UnTrack(co);
Py_XDECREF(co->pool);
Py_XDECREF(co->result);
PyMem_Free(co->indices);
if (co->indices != NULL)
PyMem_Free(co->indices);
Py_TYPE(co)->tp_free(co);
}
@ -2060,6 +2062,252 @@ static PyTypeObject combinations_type = {
};
/* combinations with replacement object *******************************************/
/* Equivalent to:
def combinations_with_replacement(iterable, r):
"combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC"
# number items returned: (n+r-1)! / r! / (n-1)! when n > 0
pool = tuple(iterable)
n = len(pool)
if not n and r:
return
indices = [0] * r
yield tuple(pool[i] for i in indices)
while 1:
for i in reversed(range(r)):
if indices[i] != n - 1:
break
else:
return
indices[i:] = [indices[i] + 1] * (r - i)
yield tuple(pool[i] for i in indices)
def combinations_with_replacement2(iterable, r):
'Alternate version that filters from product()'
pool = tuple(iterable)
n = len(pool)
for indices in product(range(n), repeat=r):
if sorted(indices) == list(indices):
yield tuple(pool[i] for i in indices)
*/
/* Per-iterator state for combinations_with_replacement.  The fields are
   filled in by cwr_new() and advanced by cwr_next(). */
typedef struct {
PyObject_HEAD
PyObject *pool; /* input converted to a tuple */
Py_ssize_t *indices; /* one index per result element (r entries, each a position in pool) */
PyObject *result; /* most recently returned result tuple; NULL until the first cwr_next() call */
Py_ssize_t r; /* size of result tuple */
int stopped; /* set to 1 when the cwr iterator is exhausted */
} cwrobject;
static PyTypeObject cwr_type;
/* tp_new for combinations_with_replacement(iterable, r).

   Materializes the iterable into a tuple, validates r, and allocates the
   index array (all zeros, i.e. the first combination).  Returns a new
   cwrobject, or NULL with an exception set:
     - TypeError from argument parsing or a non-iterable input,
     - ValueError when r is negative,
     - MemoryError when the index array cannot be allocated. */
static PyObject *
cwr_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
{
    cwrobject *co;
    Py_ssize_t n;
    Py_ssize_t r;
    PyObject *pool = NULL;
    PyObject *iterable = NULL;
    Py_ssize_t *indices = NULL;
    Py_ssize_t i;
    static char *kwargs[] = {"iterable", "r", NULL};

    if (!PyArg_ParseTupleAndKeywords(args, kwds, "On:combinations_with_replacement", kwargs,
                                     &iterable, &r))
        return NULL;

    pool = PySequence_Tuple(iterable);
    if (pool == NULL)
        goto error;
    n = PyTuple_GET_SIZE(pool);
    if (r < 0) {
        PyErr_SetString(PyExc_ValueError, "r must be non-negative");
        goto error;
    }

    /* PyMem_New() returns NULL when r * sizeof(Py_ssize_t) would overflow,
       whereas a bare PyMem_Malloc(r * sizeof(...)) could wrap around to a
       tiny allocation that the zero-fill loop below would then overrun. */
    indices = PyMem_New(Py_ssize_t, r);
    if (indices == NULL) {
        PyErr_NoMemory();
        goto error;
    }

    for (i=0 ; i<r ; i++)
        indices[i] = 0;

    /* create cwrobject structure */
    co = (cwrobject *)type->tp_alloc(type, 0);
    if (co == NULL)
        goto error;

    /* Ownership of pool and indices moves to the new object here. */
    co->pool = pool;
    co->indices = indices;
    co->result = NULL;
    co->r = r;
    /* An empty pool with r > 0 has no combinations at all: start exhausted. */
    co->stopped = !n && r;

    return (PyObject *)co;

error:
    if (indices != NULL)
        PyMem_Free(indices);
    Py_XDECREF(pool);
    return NULL;
}
/* tp_dealloc: untrack the object from the garbage collector, drop the
   references to the pool and result tuples, free the index array, and
   finally release the object's own memory through its type's tp_free. */
static void
cwr_dealloc(cwrobject *co)
{
PyObject_GC_UnTrack(co);
Py_XDECREF(co->pool);
Py_XDECREF(co->result);
if (co->indices != NULL)
PyMem_Free(co->indices);
Py_TYPE(co)->tp_free(co);
}
/* tp_traverse: report the two object references held by the iterator
   (pool and result) to the garbage collector.  The indices array holds
   plain Py_ssize_t values, so there is nothing to visit there. */
static int
cwr_traverse(cwrobject *co, visitproc visit, void *arg)
{
Py_VISIT(co->pool);
Py_VISIT(co->result);
return 0;
}
/* tp_iternext: return the next combination as a tuple, or NULL once the
   iterator has stopped.  The tuple kept in co->result is updated in place
   and handed out again whenever this object holds the only reference to
   it; otherwise a fresh copy is made first. */
static PyObject *
cwr_next(cwrobject *co)
{
PyObject *elem;
PyObject *oldelem;
PyObject *pool = co->pool;
Py_ssize_t *indices = co->indices;
PyObject *result = co->result;
Py_ssize_t n = PyTuple_GET_SIZE(pool);
Py_ssize_t r = co->r;
Py_ssize_t i, j, index;
if (co->stopped)
return NULL;
if (result == NULL) {
/* On the first pass, initialize result tuple using the indices */
result = PyTuple_New(r);
if (result == NULL)
goto empty;
co->result = result;
for (i=0; i<r ; i++) {
index = indices[i];
elem = PyTuple_GET_ITEM(pool, index);
Py_INCREF(elem);
PyTuple_SET_ITEM(result, i, elem);
}
} else {
/* Copy the previous result tuple or re-use it if available */
if (Py_REFCNT(result) > 1) {
PyObject *old_result = result;
result = PyTuple_New(r);
if (result == NULL)
goto empty;
co->result = result;
for (i=0; i<r ; i++) {
elem = PyTuple_GET_ITEM(old_result, i);
Py_INCREF(elem);
PyTuple_SET_ITEM(result, i, elem);
}
Py_DECREF(old_result);
}
/* Now, we've got the only copy so we can update it in-place.
   The r == 0 case is exempt from the assertion because CPython's
   empty tuple is a singleton and cached in PyTuple's freelist. */
assert(r == 0 || Py_REFCNT(result) == 1);
/* Scan indices right-to-left until finding one that is not
 * at its maximum (n-1). */
for (i=r-1 ; i >= 0 && indices[i] == n-1; i--)
;
/* If i is negative, then the indices are all at
their maximum value and we're done. */
if (i < 0)
goto empty;
/* Increment the current index which we know is not at its
maximum. Then set all to the right to the same value. */
indices[i]++;
for (j=i+1 ; j<r ; j++)
indices[j] = indices[j-1];
/* Update the result tuple for the new indices
starting with i, the leftmost index that changed */
for ( ; i<r ; i++) {
index = indices[i];
elem = PyTuple_GET_ITEM(pool, index);
Py_INCREF(elem);
oldelem = PyTuple_GET_ITEM(result, i);
PyTuple_SET_ITEM(result, i, elem);
Py_DECREF(oldelem);
}
}
/* co->result keeps its own reference; the caller receives a new one. */
Py_INCREF(result);
return result;
/* Reached on exhaustion and also when PyTuple_New() fails above; either
   way the iterator is marked stopped so later calls return NULL at once. */
empty:
co->stopped = 1;
return NULL;
}
/* Docstring for the type.  r is a required argument (cwr_new parses its
   arguments with the "On" format), so the signature lists it without
   optional brackets. */
PyDoc_STRVAR(cwr_doc,
"combinations_with_replacement(iterable, r) --> combinations_with_replacement object\n\
\n\
Return successive r-length combinations of elements in the iterable\n\
allowing individual elements to have successive repeats.\n\
combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC");
/* Type object for itertools.combinations_with_replacement: a GC-enabled,
   subclassable iterator type.  Iterating returns the object itself
   (PyObject_SelfIter) and each step is produced by cwr_next(). */
static PyTypeObject cwr_type = {
PyVarObject_HEAD_INIT(NULL, 0)
"itertools.combinations_with_replacement", /* tp_name */
sizeof(cwrobject), /* tp_basicsize */
0, /* tp_itemsize */
/* methods */
(destructor)cwr_dealloc, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_compare */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
PyObject_GenericGetAttr, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
Py_TPFLAGS_BASETYPE, /* tp_flags */
cwr_doc, /* tp_doc */
(traverseproc)cwr_traverse, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
PyObject_SelfIter, /* tp_iter */
(iternextfunc)cwr_next, /* tp_iternext */
0, /* tp_methods */
0, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
0, /* tp_init */
0, /* tp_alloc */
cwr_new, /* tp_new */
PyObject_GC_Del, /* tp_free */
};
/* permutations object ************************************************************
def permutations(iterable, r=None):
@ -3191,6 +3439,7 @@ PyInit_itertools(void)
char *name;
PyTypeObject *typelist[] = {
&combinations_type,
&cwr_type,
&cycle_type,
&dropwhile_type,
&takewhile_type,