Add itertools.accumulate().

This commit is contained in:
Raymond Hettinger 2010-12-01 22:48:00 +00:00
parent 2f9a77a389
commit 482ba77245
3 changed files with 179 additions and 0 deletions

View File

@ -56,6 +56,23 @@ def fact(n):
return prod(range(1, n+1))
class TestBasicOps(unittest.TestCase):
def test_accumulate(self):
self.assertEqual(list(accumulate(range(10))), # one positional arg
[0, 1, 3, 6, 10, 15, 21, 28, 36, 45])
self.assertEqual(list(accumulate(range(10), 100)), # two positional args
[100, 101, 103, 106, 110, 115, 121, 128, 136, 145])
self.assertEqual(list(accumulate(iterable=range(10), start=100)), # kw args
[100, 101, 103, 106, 110, 115, 121, 128, 136, 145])
for typ in int, complex, Decimal, Fraction: # multiple types
self.assertEqual(list(accumulate(range(10), typ(0))),
list(map(typ, [0, 1, 3, 6, 10, 15, 21, 28, 36, 45])))
self.assertEqual(list(accumulate([])), []) # empty iterable
self.assertRaises(TypeError, accumulate, range(10), 0, 5) # too many args
self.assertRaises(TypeError, accumulate) # too few args
self.assertRaises(TypeError, accumulate, range(10), x=7) # unexpected kwd args
self.assertRaises(TypeError, list, accumulate([1, []])) # args that don't add
def test_chain(self):
def chain2(*iterables):
@ -932,6 +949,9 @@ class TestBasicOps(unittest.TestCase):
class TestExamples(unittest.TestCase):
def test_accumlate(self):
self.assertEqual(list(accumulate([1,2,3,4,5])), [1, 3, 6, 10, 15])
def test_chain(self):
self.assertEqual(''.join(chain('ABC', 'DEF')), 'ABCDEF')
@ -1019,6 +1039,10 @@ class TestGC(unittest.TestCase):
next(iterator)
del container, iterator
def test_accumulate(self):
a = []
self.makecycle(accumulate([1,2,a,3]), a)
def test_chain(self):
a = []
self.makecycle(chain(a), a)
@ -1188,6 +1212,17 @@ def L(seqn):
class TestVariousIteratorArgs(unittest.TestCase):
def test_accumulate(self):
s = [1,2,3,4,5]
r = [1,3,6,10,15]
n = len(s)
for g in (G, I, Ig, L, R):
self.assertEqual(list(accumulate(g(s))), r)
self.assertEqual(list(accumulate(S(s))), [])
self.assertRaises(TypeError, accumulate, X(s))
self.assertRaises(TypeError, accumulate, N(s))
self.assertRaises(ZeroDivisionError, list, accumulate(E(s)))
def test_chain(self):
for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)):
for g in (G, I, Ig, S, L, R):

View File

@ -46,6 +46,8 @@ Core and Builtins
Library
-------
- Added itertools.accumulate().
- Issue #4113: Added custom ``__repr__`` method to ``functools.partial``.
Original patch by Daniel Urban.

View File

@ -2584,6 +2584,146 @@ static PyTypeObject permutations_type = {
PyObject_GC_Del, /* tp_free */
};
/* accumulate object ************************************************************/
typedef struct {
PyObject_HEAD
PyObject *total;
PyObject *it;
} accumulateobject;
static PyTypeObject accumulate_type;
static PyObject *
accumulate_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
{
static char *kwargs[] = {"iterable", "start", NULL};
PyObject *iterable;
PyObject *it;
PyObject *start = NULL;
accumulateobject *lz;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:accumulate",
kwargs, &iterable, &start))
return NULL;
/* Get iterator. */
it = PyObject_GetIter(iterable);
if (it == NULL)
return NULL;
/* Default start value */
if (start == NULL) {
start = PyLong_FromLong(0);
if (start == NULL) {
Py_DECREF(it);
return NULL;
}
} else {
Py_INCREF(start);
}
/* create accumulateobject structure */
lz = (accumulateobject *)type->tp_alloc(type, 0);
if (lz == NULL) {
Py_DECREF(it);
Py_DECREF(start);
return NULL;
}
lz->total = start;
lz->it = it;
return (PyObject *)lz;
}
static void
accumulate_dealloc(accumulateobject *lz)
{
PyObject_GC_UnTrack(lz);
Py_XDECREF(lz->total);
Py_XDECREF(lz->it);
Py_TYPE(lz)->tp_free(lz);
}
static int
accumulate_traverse(accumulateobject *lz, visitproc visit, void *arg)
{
Py_VISIT(lz->it);
Py_VISIT(lz->total);
return 0;
}
static PyObject *
accumulate_next(accumulateobject *lz)
{
PyObject *val, *oldtotal, *newtotal;
val = PyIter_Next(lz->it);
if (val == NULL)
return NULL;
newtotal = PyNumber_Add(lz->total, val);
Py_DECREF(val);
if (newtotal == NULL)
return NULL;
oldtotal = lz->total;
lz->total = newtotal;
Py_DECREF(oldtotal);
Py_INCREF(newtotal);
return newtotal;
}
PyDoc_STRVAR(accumulate_doc,
"accumulate(iterable, start=0) --> accumulate object\n\
\n\
Return series of accumulated sums.");
static PyTypeObject accumulate_type = {
PyVarObject_HEAD_INIT(NULL, 0)
"itertools.accumulate", /* tp_name */
sizeof(accumulateobject), /* tp_basicsize */
0, /* tp_itemsize */
/* methods */
(destructor)accumulate_dealloc, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_reserved */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
PyObject_GenericGetAttr, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
Py_TPFLAGS_BASETYPE, /* tp_flags */
accumulate_doc, /* tp_doc */
(traverseproc)accumulate_traverse, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
PyObject_SelfIter, /* tp_iter */
(iternextfunc)accumulate_next, /* tp_iternext */
0, /* tp_methods */
0, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
0, /* tp_init */
0, /* tp_alloc */
accumulate_new, /* tp_new */
PyObject_GC_Del, /* tp_free */
};
/* compress object ************************************************************/
@ -3496,6 +3636,7 @@ cycle(p) --> p0, p1, ... plast, p0, p1, ...\n\
repeat(elem [,n]) --> elem, elem, elem, ... endlessly or up to n times\n\
\n\
Iterators terminating on the shortest input sequence:\n\
accumulate(p, start=0) --> p0, p0+p1, p0+p1+p2\n\
chain(p, q, ...) --> p0, p1, ... plast, q0, q1, ... \n\
compress(data, selectors) --> (d[0] if s[0]), (d[1] if s[1]), ...\n\
dropwhile(pred, seq) --> seq[n], seq[n+1], starting when pred fails\n\
@ -3541,6 +3682,7 @@ PyInit_itertools(void)
PyObject *m;
char *name;
PyTypeObject *typelist[] = {
&accumulate_type,
&combinations_type,
&cwr_type,
&cycle_type,