Add itertools.accumulate().
This commit is contained in:
parent
2f9a77a389
commit
482ba77245
|
@ -56,6 +56,23 @@ def fact(n):
|
|||
return prod(range(1, n+1))
|
||||
|
||||
class TestBasicOps(unittest.TestCase):
|
||||
|
||||
def test_accumulate(self):
|
||||
self.assertEqual(list(accumulate(range(10))), # one positional arg
|
||||
[0, 1, 3, 6, 10, 15, 21, 28, 36, 45])
|
||||
self.assertEqual(list(accumulate(range(10), 100)), # two positional args
|
||||
[100, 101, 103, 106, 110, 115, 121, 128, 136, 145])
|
||||
self.assertEqual(list(accumulate(iterable=range(10), start=100)), # kw args
|
||||
[100, 101, 103, 106, 110, 115, 121, 128, 136, 145])
|
||||
for typ in int, complex, Decimal, Fraction: # multiple types
|
||||
self.assertEqual(list(accumulate(range(10), typ(0))),
|
||||
list(map(typ, [0, 1, 3, 6, 10, 15, 21, 28, 36, 45])))
|
||||
self.assertEqual(list(accumulate([])), []) # empty iterable
|
||||
self.assertRaises(TypeError, accumulate, range(10), 0, 5) # too many args
|
||||
self.assertRaises(TypeError, accumulate) # too few args
|
||||
self.assertRaises(TypeError, accumulate, range(10), x=7) # unexpected kwd args
|
||||
self.assertRaises(TypeError, list, accumulate([1, []])) # args that don't add
|
||||
|
||||
def test_chain(self):
|
||||
|
||||
def chain2(*iterables):
|
||||
|
@ -932,6 +949,9 @@ class TestBasicOps(unittest.TestCase):
|
|||
|
||||
class TestExamples(unittest.TestCase):
|
||||
|
||||
def test_accumlate(self):
|
||||
self.assertEqual(list(accumulate([1,2,3,4,5])), [1, 3, 6, 10, 15])
|
||||
|
||||
def test_chain(self):
|
||||
self.assertEqual(''.join(chain('ABC', 'DEF')), 'ABCDEF')
|
||||
|
||||
|
@ -1019,6 +1039,10 @@ class TestGC(unittest.TestCase):
|
|||
next(iterator)
|
||||
del container, iterator
|
||||
|
||||
def test_accumulate(self):
|
||||
a = []
|
||||
self.makecycle(accumulate([1,2,a,3]), a)
|
||||
|
||||
def test_chain(self):
|
||||
a = []
|
||||
self.makecycle(chain(a), a)
|
||||
|
@ -1188,6 +1212,17 @@ def L(seqn):
|
|||
|
||||
class TestVariousIteratorArgs(unittest.TestCase):
|
||||
|
||||
def test_accumulate(self):
|
||||
s = [1,2,3,4,5]
|
||||
r = [1,3,6,10,15]
|
||||
n = len(s)
|
||||
for g in (G, I, Ig, L, R):
|
||||
self.assertEqual(list(accumulate(g(s))), r)
|
||||
self.assertEqual(list(accumulate(S(s))), [])
|
||||
self.assertRaises(TypeError, accumulate, X(s))
|
||||
self.assertRaises(TypeError, accumulate, N(s))
|
||||
self.assertRaises(ZeroDivisionError, list, accumulate(E(s)))
|
||||
|
||||
def test_chain(self):
|
||||
for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)):
|
||||
for g in (G, I, Ig, S, L, R):
|
||||
|
|
|
@ -46,6 +46,8 @@ Core and Builtins
|
|||
Library
|
||||
-------
|
||||
|
||||
- Added itertools.accumulate().
|
||||
|
||||
- Issue #4113: Added custom ``__repr__`` method to ``functools.partial``.
|
||||
Original patch by Daniel Urban.
|
||||
|
||||
|
|
|
@ -2584,6 +2584,146 @@ static PyTypeObject permutations_type = {
|
|||
PyObject_GC_Del, /* tp_free */
|
||||
};
|
||||
|
||||
/* accumulate object ************************************************************/
|
||||
|
||||
typedef struct {
|
||||
PyObject_HEAD
|
||||
PyObject *total;
|
||||
PyObject *it;
|
||||
} accumulateobject;
|
||||
|
||||
static PyTypeObject accumulate_type;
|
||||
|
||||
static PyObject *
|
||||
accumulate_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
|
||||
{
|
||||
static char *kwargs[] = {"iterable", "start", NULL};
|
||||
PyObject *iterable;
|
||||
PyObject *it;
|
||||
PyObject *start = NULL;
|
||||
accumulateobject *lz;
|
||||
|
||||
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:accumulate",
|
||||
kwargs, &iterable, &start))
|
||||
return NULL;
|
||||
|
||||
/* Get iterator. */
|
||||
it = PyObject_GetIter(iterable);
|
||||
if (it == NULL)
|
||||
return NULL;
|
||||
|
||||
/* Default start value */
|
||||
if (start == NULL) {
|
||||
start = PyLong_FromLong(0);
|
||||
if (start == NULL) {
|
||||
Py_DECREF(it);
|
||||
return NULL;
|
||||
}
|
||||
} else {
|
||||
Py_INCREF(start);
|
||||
}
|
||||
|
||||
/* create accumulateobject structure */
|
||||
lz = (accumulateobject *)type->tp_alloc(type, 0);
|
||||
if (lz == NULL) {
|
||||
Py_DECREF(it);
|
||||
Py_DECREF(start);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
lz->total = start;
|
||||
lz->it = it;
|
||||
return (PyObject *)lz;
|
||||
}
|
||||
|
||||
static void
|
||||
accumulate_dealloc(accumulateobject *lz)
|
||||
{
|
||||
PyObject_GC_UnTrack(lz);
|
||||
Py_XDECREF(lz->total);
|
||||
Py_XDECREF(lz->it);
|
||||
Py_TYPE(lz)->tp_free(lz);
|
||||
}
|
||||
|
||||
static int
|
||||
accumulate_traverse(accumulateobject *lz, visitproc visit, void *arg)
|
||||
{
|
||||
Py_VISIT(lz->it);
|
||||
Py_VISIT(lz->total);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
accumulate_next(accumulateobject *lz)
|
||||
{
|
||||
PyObject *val, *oldtotal, *newtotal;
|
||||
|
||||
val = PyIter_Next(lz->it);
|
||||
if (val == NULL)
|
||||
return NULL;
|
||||
|
||||
newtotal = PyNumber_Add(lz->total, val);
|
||||
Py_DECREF(val);
|
||||
if (newtotal == NULL)
|
||||
return NULL;
|
||||
|
||||
oldtotal = lz->total;
|
||||
lz->total = newtotal;
|
||||
Py_DECREF(oldtotal);
|
||||
|
||||
Py_INCREF(newtotal);
|
||||
return newtotal;
|
||||
}
|
||||
|
||||
PyDoc_STRVAR(accumulate_doc,
|
||||
"accumulate(iterable, start=0) --> accumulate object\n\
|
||||
\n\
|
||||
Return series of accumulated sums.");
|
||||
|
||||
static PyTypeObject accumulate_type = {
|
||||
PyVarObject_HEAD_INIT(NULL, 0)
|
||||
"itertools.accumulate", /* tp_name */
|
||||
sizeof(accumulateobject), /* tp_basicsize */
|
||||
0, /* tp_itemsize */
|
||||
/* methods */
|
||||
(destructor)accumulate_dealloc, /* tp_dealloc */
|
||||
0, /* tp_print */
|
||||
0, /* tp_getattr */
|
||||
0, /* tp_setattr */
|
||||
0, /* tp_reserved */
|
||||
0, /* tp_repr */
|
||||
0, /* tp_as_number */
|
||||
0, /* tp_as_sequence */
|
||||
0, /* tp_as_mapping */
|
||||
0, /* tp_hash */
|
||||
0, /* tp_call */
|
||||
0, /* tp_str */
|
||||
PyObject_GenericGetAttr, /* tp_getattro */
|
||||
0, /* tp_setattro */
|
||||
0, /* tp_as_buffer */
|
||||
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
|
||||
Py_TPFLAGS_BASETYPE, /* tp_flags */
|
||||
accumulate_doc, /* tp_doc */
|
||||
(traverseproc)accumulate_traverse, /* tp_traverse */
|
||||
0, /* tp_clear */
|
||||
0, /* tp_richcompare */
|
||||
0, /* tp_weaklistoffset */
|
||||
PyObject_SelfIter, /* tp_iter */
|
||||
(iternextfunc)accumulate_next, /* tp_iternext */
|
||||
0, /* tp_methods */
|
||||
0, /* tp_members */
|
||||
0, /* tp_getset */
|
||||
0, /* tp_base */
|
||||
0, /* tp_dict */
|
||||
0, /* tp_descr_get */
|
||||
0, /* tp_descr_set */
|
||||
0, /* tp_dictoffset */
|
||||
0, /* tp_init */
|
||||
0, /* tp_alloc */
|
||||
accumulate_new, /* tp_new */
|
||||
PyObject_GC_Del, /* tp_free */
|
||||
};
|
||||
|
||||
|
||||
/* compress object ************************************************************/
|
||||
|
||||
|
@ -3496,6 +3636,7 @@ cycle(p) --> p0, p1, ... plast, p0, p1, ...\n\
|
|||
repeat(elem [,n]) --> elem, elem, elem, ... endlessly or up to n times\n\
|
||||
\n\
|
||||
Iterators terminating on the shortest input sequence:\n\
|
||||
accumulate(p, start=0) --> p0, p0+p1, p0+p1+p2\n\
|
||||
chain(p, q, ...) --> p0, p1, ... plast, q0, q1, ... \n\
|
||||
compress(data, selectors) --> (d[0] if s[0]), (d[1] if s[1]), ...\n\
|
||||
dropwhile(pred, seq) --> seq[n], seq[n+1], starting when pred fails\n\
|
||||
|
@ -3541,6 +3682,7 @@ PyInit_itertools(void)
|
|||
PyObject *m;
|
||||
char *name;
|
||||
PyTypeObject *typelist[] = {
|
||||
&accumulate_type,
|
||||
&combinations_type,
|
||||
&cwr_type,
|
||||
&cycle_type,
|
||||
|
|
Loading…
Reference in New Issue