diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index fe6131abfed..8a67cff60ce 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -56,6 +56,23 @@ def fact(n): return prod(range(1, n+1)) class TestBasicOps(unittest.TestCase): + + def test_accumulate(self): + self.assertEqual(list(accumulate(range(10))), # one positional arg + [0, 1, 3, 6, 10, 15, 21, 28, 36, 45]) + self.assertEqual(list(accumulate(range(10), 100)), # two positional args + [100, 101, 103, 106, 110, 115, 121, 128, 136, 145]) + self.assertEqual(list(accumulate(iterable=range(10), start=100)), # kw args + [100, 101, 103, 106, 110, 115, 121, 128, 136, 145]) + for typ in int, complex, Decimal, Fraction: # multiple types + self.assertEqual(list(accumulate(range(10), typ(0))), + list(map(typ, [0, 1, 3, 6, 10, 15, 21, 28, 36, 45]))) + self.assertEqual(list(accumulate([])), []) # empty iterable + self.assertRaises(TypeError, accumulate, range(10), 0, 5) # too many args + self.assertRaises(TypeError, accumulate) # too few args + self.assertRaises(TypeError, accumulate, range(10), x=7) # unexpected kwd args + self.assertRaises(TypeError, list, accumulate([1, []])) # args that don't add + def test_chain(self): def chain2(*iterables): @@ -932,6 +949,9 @@ class TestBasicOps(unittest.TestCase): class TestExamples(unittest.TestCase): + def test_accumlate(self): + self.assertEqual(list(accumulate([1,2,3,4,5])), [1, 3, 6, 10, 15]) + def test_chain(self): self.assertEqual(''.join(chain('ABC', 'DEF')), 'ABCDEF') @@ -1019,6 +1039,10 @@ class TestGC(unittest.TestCase): next(iterator) del container, iterator + def test_accumulate(self): + a = [] + self.makecycle(accumulate([1,2,a,3]), a) + def test_chain(self): a = [] self.makecycle(chain(a), a) @@ -1188,6 +1212,17 @@ def L(seqn): class TestVariousIteratorArgs(unittest.TestCase): + def test_accumulate(self): + s = [1,2,3,4,5] + r = [1,3,6,10,15] + n = len(s) + for g in (G, I, Ig, L, R): + self.assertEqual(list(accumulate(g(s))), r) + self.assertEqual(list(accumulate(S(s))), []) + self.assertRaises(TypeError, accumulate, X(s)) + self.assertRaises(TypeError, accumulate, N(s)) + self.assertRaises(ZeroDivisionError, list, accumulate(E(s))) + def test_chain(self): for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)): for g in (G, I, Ig, S, L, R): diff --git a/Misc/NEWS b/Misc/NEWS index 7d77b20a328..494087c8837 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -46,6 +46,8 @@ Core and Builtins Library ------- +- Added itertools.accumulate(). + - Issue #4113: Added custom ``__repr__`` method to ``functools.partial``. Original patch by Daniel Urban. diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index d5336f24294..04bfffc5b0d 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -2584,6 +2584,146 @@ static PyTypeObject permutations_type = { PyObject_GC_Del, /* tp_free */ }; +/* accumulate object ************************************************************/ + +typedef struct { + PyObject_HEAD + PyObject *total; + PyObject *it; +} accumulateobject; + +static PyTypeObject accumulate_type; + +static PyObject * +accumulate_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + static char *kwargs[] = {"iterable", "start", NULL}; + PyObject *iterable; + PyObject *it; + PyObject *start = NULL; + accumulateobject *lz; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:accumulate", + kwargs, &iterable, &start)) + return NULL; + + /* Get iterator. */ + it = PyObject_GetIter(iterable); + if (it == NULL) + return NULL; + + /* Default start value */ + if (start == NULL) { + start = PyLong_FromLong(0); + if (start == NULL) { + Py_DECREF(it); + return NULL; + } + } else { + Py_INCREF(start); + } + + /* create accumulateobject structure */ + lz = (accumulateobject *)type->tp_alloc(type, 0); + if (lz == NULL) { + Py_DECREF(it); + Py_DECREF(start); + return NULL; + } + + lz->total = start; + lz->it = it; + return (PyObject *)lz; +} + +static void +accumulate_dealloc(accumulateobject *lz) +{ + PyObject_GC_UnTrack(lz); + Py_XDECREF(lz->total); + Py_XDECREF(lz->it); + Py_TYPE(lz)->tp_free(lz); +} + +static int +accumulate_traverse(accumulateobject *lz, visitproc visit, void *arg) +{ + Py_VISIT(lz->it); + Py_VISIT(lz->total); + return 0; +} + +static PyObject * +accumulate_next(accumulateobject *lz) +{ + PyObject *val, *oldtotal, *newtotal; + + val = PyIter_Next(lz->it); + if (val == NULL) + return NULL; + + newtotal = PyNumber_Add(lz->total, val); + Py_DECREF(val); + if (newtotal == NULL) + return NULL; + + oldtotal = lz->total; + lz->total = newtotal; + Py_DECREF(oldtotal); + + Py_INCREF(newtotal); + return newtotal; +} + +PyDoc_STRVAR(accumulate_doc, +"accumulate(iterable, start=0) --> accumulate object\n\ +\n\ +Return series of accumulated sums."); + +static PyTypeObject accumulate_type = { + PyVarObject_HEAD_INIT(NULL, 0) + "itertools.accumulate", /* tp_name */ + sizeof(accumulateobject), /* tp_basicsize */ + 0, /* tp_itemsize */ + /* methods */ + (destructor)accumulate_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + PyObject_GenericGetAttr, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_BASETYPE, /* tp_flags */ + accumulate_doc, /* tp_doc */ + (traverseproc)accumulate_traverse, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + PyObject_SelfIter, /* tp_iter */ + (iternextfunc)accumulate_next, /* tp_iternext */ + 0, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + 0, /* tp_init */ + 0, /* tp_alloc */ + accumulate_new, /* tp_new */ + PyObject_GC_Del, /* tp_free */ +}; + /* compress object ************************************************************/ @@ -3496,6 +3636,7 @@ cycle(p) --> p0, p1, ... plast, p0, p1, ...\n\ repeat(elem [,n]) --> elem, elem, elem, ... endlessly or up to n times\n\ \n\ Iterators terminating on the shortest input sequence:\n\ +accumulate(p, start=0) --> p0, p0+p1, p0+p1+p2\n\ chain(p, q, ...) --> p0, p1, ... plast, q0, q1, ... \n\ compress(data, selectors) --> (d[0] if s[0]), (d[1] if s[1]), ...\n\ dropwhile(pred, seq) --> seq[n], seq[n+1], starting when pred fails\n\ @@ -3541,6 +3682,7 @@ PyInit_itertools(void) PyObject *m; char *name; PyTypeObject *typelist[] = { + &accumulate_type, &combinations_type, &cwr_type, &cycle_type,