Add optional *func* argument to itertools.accumulate().

This commit is contained in:
Raymond Hettinger 2011-03-27 18:52:10 -07:00
parent af88d86699
commit 5d44613e3b
4 changed files with 56 additions and 10 deletions

View File

@ -46,7 +46,7 @@ Iterator Arguments Results
==================== ============================ ================================================= ============================================================= ==================== ============================ ================================================= =============================================================
Iterator Arguments Results Example Iterator Arguments Results Example
==================== ============================ ================================================= ============================================================= ==================== ============================ ================================================= =============================================================
:func:`accumulate` p p0, p0+p1, p0+p1+p2, ... ``accumulate([1,2,3,4,5]) --> 1 3 6 10 15`` :func:`accumulate` p [,func] p0, p0+p1, p0+p1+p2, ... ``accumulate([1,2,3,4,5]) --> 1 3 6 10 15``
:func:`chain` p, q, ... p0, p1, ... plast, q0, q1, ... ``chain('ABC', 'DEF') --> A B C D E F`` :func:`chain` p, q, ... p0, p1, ... plast, q0, q1, ... ``chain('ABC', 'DEF') --> A B C D E F``
:func:`compress` data, selectors (d[0] if s[0]), (d[1] if s[1]), ... ``compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F`` :func:`compress` data, selectors (d[0] if s[0]), (d[1] if s[1]), ... ``compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F``
:func:`dropwhile` pred, seq seq[n], seq[n+1], starting when pred fails ``dropwhile(lambda x: x<5, [1,4,6,4,1]) --> 6 4 1`` :func:`dropwhile` pred, seq seq[n], seq[n+1], starting when pred fails ``dropwhile(lambda x: x<5, [1,4,6,4,1]) --> 6 4 1``
@ -84,23 +84,46 @@ The following module functions all construct and return iterators. Some provide
streams of infinite length, so they should only be accessed by functions or streams of infinite length, so they should only be accessed by functions or
loops that truncate the stream. loops that truncate the stream.
.. function:: accumulate(iterable) .. function:: accumulate(iterable[, func])
Make an iterator that returns accumulated sums. Elements may be any addable Make an iterator that returns accumulated sums. Elements may be any addable
type including :class:`Decimal` or :class:`Fraction`. Equivalent to:: type including :class:`Decimal` or :class:`Fraction`. If the optional
*func* argument is supplied, it should be a function of two arguments
and it will be used instead of addition.
def accumulate(iterable): Equivalent to::
def accumulate(iterable, func=operator.add):
'Return running totals' 'Return running totals'
# accumulate([1,2,3,4,5]) --> 1 3 6 10 15 # accumulate([1,2,3,4,5]) --> 1 3 6 10 15
# accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
it = iter(iterable) it = iter(iterable)
total = next(it) total = next(it)
yield total yield total
for element in it: for element in it:
total = total + element total = func(total, element)
yield total yield total
Uses for the *func* argument include :func:`min` for a running minimum,
:func:`max` for a running maximum, and :func:`operator.mul` for a running
product::
>>> data = [3, 4, 6, 2, 1, 9, 0, 7, 5, 8]
>>> list(accumulate(data, operator.mul)) # running product
[3, 12, 72, 144, 144, 1296, 0, 0, 0, 0]
>>> list(accumulate(data, max)) # running maximum
[3, 4, 6, 6, 6, 9, 9, 9, 9, 9]
# Amortize a 5% loan of 1000 with 4 annual payments of 90
>>> cashflows = [1000, -90, -90, -90, -90]
>>> list(accumulate(cashflows, lambda bal, pmt: bal*1.05 + pmt))
[1000, 960.0, 918.0, 873.9000000000001, 827.5950000000001]
.. versionadded:: 3.2 .. versionadded:: 3.2
.. versionchanged:: 3.3
Added the optional *func* parameter.
.. function:: chain(*iterables) .. function:: chain(*iterables)
Make an iterator that returns elements from the first iterable until it is Make an iterator that returns elements from the first iterable until it is

View File

@ -69,11 +69,21 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(list(accumulate('abc')), ['a', 'ab', 'abc']) # works with non-numeric self.assertEqual(list(accumulate('abc')), ['a', 'ab', 'abc']) # works with non-numeric
self.assertEqual(list(accumulate([])), []) # empty iterable self.assertEqual(list(accumulate([])), []) # empty iterable
self.assertEqual(list(accumulate([7])), [7]) # iterable of length one self.assertEqual(list(accumulate([7])), [7]) # iterable of length one
self.assertRaises(TypeError, accumulate, range(10), 5) # too many args self.assertRaises(TypeError, accumulate, range(10), 5, 6) # too many args
self.assertRaises(TypeError, accumulate) # too few args self.assertRaises(TypeError, accumulate) # too few args
self.assertRaises(TypeError, accumulate, x=range(10)) # unexpected kwd arg self.assertRaises(TypeError, accumulate, x=range(10)) # unexpected kwd arg
self.assertRaises(TypeError, list, accumulate([1, []])) # args that don't add self.assertRaises(TypeError, list, accumulate([1, []])) # args that don't add
s = [2, 8, 9, 5, 7, 0, 3, 4, 1, 6]
self.assertEqual(list(accumulate(s, min)),
[2, 2, 2, 2, 2, 0, 0, 0, 0, 0])
self.assertEqual(list(accumulate(s, max)),
[2, 8, 9, 9, 9, 9, 9, 9, 9, 9])
self.assertEqual(list(accumulate(s, operator.mul)),
[2, 16, 144, 720, 5040, 0, 0, 0, 0, 0])
with self.assertRaises(TypeError):
list(accumulate(s, chr)) # unary-operation
def test_chain(self): def test_chain(self):
def chain2(*iterables): def chain2(*iterables):

View File

@ -89,6 +89,9 @@ Library
- Issue #11696: Fix ID generation in msilib. - Issue #11696: Fix ID generation in msilib.
- itertools.accumulate now supports an optional *func* argument for
a user-supplied binary function.
- Issue #11692: Remove unnecessary demo functions in subprocess module. - Issue #11692: Remove unnecessary demo functions in subprocess module.
- Issue #9696: Fix exception incorrectly raised by xdrlib.Packer.pack_int when - Issue #9696: Fix exception incorrectly raised by xdrlib.Packer.pack_int when

View File

@ -2590,6 +2590,7 @@ typedef struct {
PyObject_HEAD PyObject_HEAD
PyObject *total; PyObject *total;
PyObject *it; PyObject *it;
PyObject *binop;
} accumulateobject; } accumulateobject;
static PyTypeObject accumulate_type; static PyTypeObject accumulate_type;
@ -2597,12 +2598,14 @@ static PyTypeObject accumulate_type;
static PyObject * static PyObject *
accumulate_new(PyTypeObject *type, PyObject *args, PyObject *kwds) accumulate_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
{ {
static char *kwargs[] = {"iterable", NULL}; static char *kwargs[] = {"iterable", "func", NULL};
PyObject *iterable; PyObject *iterable;
PyObject *it; PyObject *it;
PyObject *binop = NULL;
accumulateobject *lz; accumulateobject *lz;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:accumulate", kwargs, &iterable)) if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:accumulate",
kwargs, &iterable, &binop))
return NULL; return NULL;
/* Get iterator. */ /* Get iterator. */
@ -2617,6 +2620,8 @@ accumulate_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
return NULL; return NULL;
} }
Py_XINCREF(binop);
lz->binop = binop;
lz->total = NULL; lz->total = NULL;
lz->it = it; lz->it = it;
return (PyObject *)lz; return (PyObject *)lz;
@ -2626,6 +2631,7 @@ static void
accumulate_dealloc(accumulateobject *lz) accumulate_dealloc(accumulateobject *lz)
{ {
PyObject_GC_UnTrack(lz); PyObject_GC_UnTrack(lz);
Py_XDECREF(lz->binop);
Py_XDECREF(lz->total); Py_XDECREF(lz->total);
Py_XDECREF(lz->it); Py_XDECREF(lz->it);
Py_TYPE(lz)->tp_free(lz); Py_TYPE(lz)->tp_free(lz);
@ -2634,6 +2640,7 @@ accumulate_dealloc(accumulateobject *lz)
static int static int
accumulate_traverse(accumulateobject *lz, visitproc visit, void *arg) accumulate_traverse(accumulateobject *lz, visitproc visit, void *arg)
{ {
Py_VISIT(lz->binop);
Py_VISIT(lz->it); Py_VISIT(lz->it);
Py_VISIT(lz->total); Py_VISIT(lz->total);
return 0; return 0;
@ -2653,8 +2660,11 @@ accumulate_next(accumulateobject *lz)
lz->total = val; lz->total = val;
return lz->total; return lz->total;
} }
newtotal = PyNumber_Add(lz->total, val); if (lz->binop == NULL)
newtotal = PyNumber_Add(lz->total, val);
else
newtotal = PyObject_CallFunctionObjArgs(lz->binop, lz->total, val, NULL);
Py_DECREF(val); Py_DECREF(val);
if (newtotal == NULL) if (newtotal == NULL)
return NULL; return NULL;