Add itertools.combinations().

This commit is contained in:
Raymond Hettinger 2008-02-26 23:40:50 +00:00
parent 3ef2063ec8
commit 93e804da9c
4 changed files with 266 additions and 16 deletions

View File

@ -97,21 +97,21 @@ loops that truncate the stream.
def combinations(iterable, r):
pool = tuple(iterable)
if pool:
n = len(pool)
vec = range(r)
yield tuple(pool[i] for i in vec)
while 1:
for i in reversed(range(r)):
if vec[i] == i + n-r:
continue
vec[i] += 1
for j in range(i+1, r):
vec[j] = vec[j-1] + 1
yield tuple(pool[i] for i in vec)
break
else:
return
n = len(pool)
assert 0 <= r <= n
vec = range(r)
yield tuple(pool[i] for i in vec)
while 1:
for i in reversed(range(r)):
if vec[i] == i + n-r:
continue
vec[i] += 1
for j in range(i+1, r):
vec[j] = vec[j-1] + 1
yield tuple(pool[i] for i in vec)
break
else:
return
.. versionadded:: 2.6

View File

@ -40,6 +40,10 @@ def take(n, seq):
'Convenience function for partially consuming a long of infinite iterable'
return list(islice(seq, n))
def fact(n):
'Factorial'
return reduce(operator.mul, range(1, n+1), 1)
class TestBasicOps(unittest.TestCase):
def test_chain(self):
self.assertEqual(list(chain('abc', 'def')), list('abcdef'))
@ -48,6 +52,26 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(take(4, chain('abc', 'def')), list('abcd'))
self.assertRaises(TypeError, chain, 2, 3)
def test_combinations(self):
self.assertRaises(TypeError, combinations, 'abc') # missing r argument
self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments
self.assertRaises(ValueError, combinations, 'abc', -2) # r is negative
self.assertRaises(ValueError, combinations, 'abc', 32) # r is too big
self.assertEqual(list(combinations(range(4), 3)),
[(0,1,2), (0,1,3), (0,2,3), (1,2,3)])
for n in range(6):
values = [5*x-12 for x in range(n)]
for r in range(n+1):
result = list(combinations(values, r))
self.assertEqual(len(result), fact(n) / fact(r) / fact(n-r)) # right number of combs
self.assertEqual(len(result), len(set(result))) # no repeats
self.assertEqual(result, sorted(result)) # lexicographic order
for c in result:
self.assertEqual(len(c), r) # r-length combinations
self.assertEqual(len(set(c)), r) # no duplicate elements
self.assertEqual(list(c), sorted(c)) # keep original ordering
self.assert_(all(e in values for e in c)) # elements taken from input iterable
def test_count(self):
self.assertEqual(zip('abc',count()), [('a', 0), ('b', 1), ('c', 2)])
self.assertEqual(zip('abc',count(3)), [('a', 3), ('b', 4), ('c', 5)])

View File

@ -670,6 +670,8 @@ Library
- Added itertools.product() which forms the Cartesian product of
the input iterables.
- Added itertools.combinations().
- Patch #1541463: optimize performance of cgi.FieldStorage operations.
- Decimal is fully updated to the latest Decimal Specification (v1.66).

View File

@ -1982,6 +1982,229 @@ static PyTypeObject product_type = {
};
/* combinations object ************************************************************/
typedef struct {
PyObject_HEAD
PyObject *pool; /* input converted to a tuple */
Py_ssize_t *indices; /* one index per result element */
PyObject *result; /* most recently returned result tuple */
Py_ssize_t r; /* size of result tuple */
int stopped; /* set to 1 when the combinations iterator is exhausted */
} combinationsobject;
static PyTypeObject combinations_type;
static PyObject *
combinations_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
{
combinationsobject *co;
Py_ssize_t n;
Py_ssize_t r;
PyObject *pool = NULL;
PyObject *iterable = NULL;
Py_ssize_t *indices = NULL;
Py_ssize_t i;
static char *kwargs[] = {"iterable", "r", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwds, "On:combinations", kwargs,
&iterable, &r))
return NULL;
pool = PySequence_Tuple(iterable);
if (pool == NULL)
goto error;
n = PyTuple_GET_SIZE(pool);
if (r < 0) {
PyErr_SetString(PyExc_ValueError, "r must be non-negative");
goto error;
}
if (r > n) {
PyErr_SetString(PyExc_ValueError, "r cannot be bigger than the iterable");
goto error;
}
indices = PyMem_Malloc(r * sizeof(Py_ssize_t));
if (indices == NULL) {
PyErr_NoMemory();
goto error;
}
for (i=0 ; i<r ; i++)
indices[i] = i;
/* create combinationsobject structure */
co = (combinationsobject *)type->tp_alloc(type, 0);
if (co == NULL)
goto error;
co->pool = pool;
co->indices = indices;
co->result = NULL;
co->r = r;
co->stopped = 0;
return (PyObject *)co;
error:
if (indices != NULL)
PyMem_Free(indices);
Py_XDECREF(pool);
return NULL;
}
static void
combinations_dealloc(combinationsobject *co)
{
PyObject_GC_UnTrack(co);
Py_XDECREF(co->pool);
Py_XDECREF(co->result);
PyMem_Free(co->indices);
Py_TYPE(co)->tp_free(co);
}
static int
combinations_traverse(combinationsobject *co, visitproc visit, void *arg)
{
Py_VISIT(co->pool);
Py_VISIT(co->result);
return 0;
}
static PyObject *
combinations_next(combinationsobject *co)
{
PyObject *elem;
PyObject *oldelem;
PyObject *pool = co->pool;
Py_ssize_t *indices = co->indices;
PyObject *result = co->result;
Py_ssize_t n = PyTuple_GET_SIZE(pool);
Py_ssize_t r = co->r;
Py_ssize_t i, j, index;
if (co->stopped)
return NULL;
if (result == NULL) {
/* On the first pass, initialize result tuple using the indices */
result = PyTuple_New(r);
if (result == NULL)
goto empty;
co->result = result;
for (i=0; i<r ; i++) {
index = indices[i];
elem = PyTuple_GET_ITEM(pool, index);
Py_INCREF(elem);
PyTuple_SET_ITEM(result, i, elem);
}
} else {
/* Copy the previous result tuple or re-use it if available */
if (Py_REFCNT(result) > 1) {
PyObject *old_result = result;
result = PyTuple_New(r);
if (result == NULL)
goto empty;
co->result = result;
for (i=0; i<r ; i++) {
elem = PyTuple_GET_ITEM(old_result, i);
Py_INCREF(elem);
PyTuple_SET_ITEM(result, i, elem);
}
Py_DECREF(old_result);
}
/* Now, we've got the only copy so we can update it in-place */
assert (Py_REFCNT(result) == 1);
/* Scan indices right-to-left until finding one that is not
at its maximum (i + n - r). */
for (i=r-1 ; i >= 0 && indices[i] == i+n-r ; i--)
;
/* If i is negative, then the indices are all at
their maximum value and we're done. */
if (i < 0)
goto empty;
/* Increment the current index which we know is not at its
maximum. Then move back to the right setting each index
to its lowest possible value (one higher than the index
to its left -- this maintains the sort order invariant). */
indices[i]++;
for (j=i+1 ; j<r ; j++)
indices[j] = indices[j-1] + 1;
/* Update the result tuple for the new indices
starting with i, the leftmost index that changed */
for ( ; i<r ; i++) {
index = indices[i];
elem = PyTuple_GET_ITEM(pool, index);
Py_INCREF(elem);
oldelem = PyTuple_GET_ITEM(result, i);
PyTuple_SET_ITEM(result, i, elem);
Py_DECREF(oldelem);
}
}
Py_INCREF(result);
return result;
empty:
co->stopped = 1;
return NULL;
}
PyDoc_STRVAR(combinations_doc,
"combinations(iterables) --> combinations object\n\
\n\
Return successive r-length combinations of elements in the iterable.\n\n\
combinations(range(4), 3) --> (0,1,2), (0,1,3), (0,2,3), (1,2,3)");
static PyTypeObject combinations_type = {
PyVarObject_HEAD_INIT(NULL, 0)
"itertools.combinations", /* tp_name */
sizeof(combinationsobject), /* tp_basicsize */
0, /* tp_itemsize */
/* methods */
(destructor)combinations_dealloc, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_compare */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
PyObject_GenericGetAttr, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
Py_TPFLAGS_BASETYPE, /* tp_flags */
combinations_doc, /* tp_doc */
(traverseproc)combinations_traverse, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
PyObject_SelfIter, /* tp_iter */
(iternextfunc)combinations_next, /* tp_iternext */
0, /* tp_methods */
0, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
0, /* tp_init */
0, /* tp_alloc */
combinations_new, /* tp_new */
PyObject_GC_Del, /* tp_free */
};
/* ifilter object ************************************************************/
typedef struct {
@ -3026,6 +3249,7 @@ inititertools(void)
PyObject *m;
char *name;
PyTypeObject *typelist[] = {
&combinations_type,
&cycle_type,
&dropwhile_type,
&takewhile_type,
@ -3038,7 +3262,7 @@ inititertools(void)
&count_type,
&izip_type,
&iziplongest_type,
&product_type,
&product_type,
&repeat_type,
&groupby_type,
NULL