GH-98363: Add itertools.batched() (GH-98364)

This commit is contained in:
Raymond Hettinger 2022-10-17 18:53:45 -05:00 committed by GitHub
parent 70732d8a4c
commit de3ece769a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 370 additions and 39 deletions

View File

@ -48,6 +48,7 @@ Iterator Arguments Results
Iterator Arguments Results Example
============================ ============================ ================================================= =============================================================
:func:`accumulate` p [,func] p0, p0+p1, p0+p1+p2, ... ``accumulate([1,2,3,4,5]) --> 1 3 6 10 15``
:func:`batched` p, n [p0, p1, ..., p_n-1], ... ``batched('ABCDEFG', n=3) --> ABC DEF G``
:func:`chain` p, q, ... p0, p1, ... plast, q0, q1, ... ``chain('ABC', 'DEF') --> A B C D E F``
:func:`chain.from_iterable` iterable p0, p1, ... plast, q0, q1, ... ``chain.from_iterable(['ABC', 'DEF']) --> A B C D E F``
:func:`compress` data, selectors (d[0] if s[0]), (d[1] if s[1]), ... ``compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F``
@ -170,6 +171,44 @@ loops that truncate the stream.
.. versionchanged:: 3.8
Added the optional *initial* parameter.
.. function:: batched(iterable, n)
Batch data from the *iterable* into lists of length *n*. The last
batch may be shorter than *n*.
Loops over the input iterable and accumulates data into lists up to
size *n*. The input is consumed lazily, just enough to fill a list.
The result is yielded as soon as the batch is full or when the input
iterable is exhausted:
.. doctest::
>>> flattened_data = ['roses', 'red', 'violets', 'blue', 'sugar', 'sweet']
>>> unflattened = list(batched(flattened_data, 2))
>>> unflattened
[['roses', 'red'], ['violets', 'blue'], ['sugar', 'sweet']]
>>> for batch in batched('ABCDEFG', 3):
... print(batch)
...
['A', 'B', 'C']
['D', 'E', 'F']
['G']
Roughly equivalent to::
def batched(iterable, n):
    # batched('ABCDEFG', 3) --> ABC DEF G
    if n < 1:
        raise ValueError('n must be at least one')
    iterator = iter(iterable)
    batch = list(islice(iterator, n))
    while batch:
        yield batch
        batch = list(islice(iterator, n))
.. versionadded:: 3.12
.. function:: chain(*iterables)
Make an iterator that returns elements from the first iterable until it is
@ -858,13 +897,6 @@ which incur interpreter overhead.
else:
raise ValueError('Expected fill, strict, or ignore')
def batched(iterable, n):
    "Batch data into lists of length n. The last batch may be shorter."
    # batched('ABCDEFG', 3) --> ABC DEF G
    iterator = iter(iterable)
    batch = list(islice(iterator, n))
    while batch:
        yield batch
        batch = list(islice(iterator, n))
def triplewise(iterable):
"Return overlapping triplets from an iterable"
# triplewise('ABCDEFG') --> ABC BCD CDE DEF EFG
@ -1211,36 +1243,6 @@ which incur interpreter overhead.
>>> list(grouper('abcdefg', n=3, incomplete='ignore'))
[('a', 'b', 'c'), ('d', 'e', 'f')]
>>> list(batched('ABCDEFG', 3))
[['A', 'B', 'C'], ['D', 'E', 'F'], ['G']]
>>> list(batched('ABCDEF', 3))
[['A', 'B', 'C'], ['D', 'E', 'F']]
>>> list(batched('ABCDE', 3))
[['A', 'B', 'C'], ['D', 'E']]
>>> list(batched('ABCD', 3))
[['A', 'B', 'C'], ['D']]
>>> list(batched('ABC', 3))
[['A', 'B', 'C']]
>>> list(batched('AB', 3))
[['A', 'B']]
>>> list(batched('A', 3))
[['A']]
>>> list(batched('', 3))
[]
>>> list(batched('ABCDEFG', 2))
[['A', 'B'], ['C', 'D'], ['E', 'F'], ['G']]
>>> list(batched('ABCDEFG', 1))
[['A'], ['B'], ['C'], ['D'], ['E'], ['F'], ['G']]
>>> list(batched('ABCDEFG', 0))
[]
>>> list(batched('ABCDEFG', -1))
Traceback (most recent call last):
...
ValueError: Stop argument for islice() must be None or an integer: 0 <= x <= sys.maxsize.
>>> s = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
>>> all(list(flatten(batched(s[:n], 5))) == list(s[:n]) for n in range(len(s)))
True
>>> list(triplewise('ABCDEFG'))
[('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E'), ('D', 'E', 'F'), ('E', 'F', 'G')]

View File

@ -159,6 +159,44 @@ class TestBasicOps(unittest.TestCase):
with self.assertRaises(TypeError):
list(accumulate([10, 20], 100))
def test_batched(self):
    # Happy-path batchings at several sizes.
    for size, expected in [
        (3, [['A', 'B', 'C'], ['D', 'E', 'F'], ['G']]),
        (2, [['A', 'B'], ['C', 'D'], ['E', 'F'], ['G']]),
        (1, [['A'], ['B'], ['C'], ['D'], ['E'], ['F'], ['G']]),
    ]:
        self.assertEqual(list(batched('ABCDEFG', size)), expected)
    # Argument validation.
    with self.assertRaises(TypeError):
        list(batched('ABCDEFG'))            # Too few arguments
    with self.assertRaises(TypeError):
        list(batched('ABCDEFG', 3, None))   # Too many arguments
    with self.assertRaises(TypeError):
        list(batched(None, 3))              # Non-iterable input
    with self.assertRaises(TypeError):
        list(batched('ABCDEFG', 'hello'))   # n is a string
    with self.assertRaises(ValueError):
        list(batched('ABCDEFG', 0))         # n is zero
    with self.assertRaises(ValueError):
        list(batched('ABCDEFG', -1))        # n is negative
    # Exhaustive check over all prefixes and several batch sizes.
    data = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    for n in range(1, 6):
        for i in range(len(data)):
            s = data[:i]
            batches = list(batched(s, n))
            with self.subTest(s=s, n=n, batches=batches):
                # Order is preserved and no data is lost
                self.assertEqual(''.join(chain.from_iterable(batches)), s)
                # Each batch is an exact list
                for batch in batches:
                    self.assertIs(type(batch), list)
                # All but the last batch is of size n
                for batch in batches[:-1]:
                    self.assertEqual(len(batch), n)
                if batches:
                    self.assertLessEqual(len(batches[-1]), n)
def test_chain(self):
def chain2(*iterables):
@ -1737,6 +1775,31 @@ class TestExamples(unittest.TestCase):
class TestPurePythonRoughEquivalents(unittest.TestCase):
def test_batched_recipe(self):
    # Pure-python equivalent from the docs; must match the C version
    # on both results and raised exception types.
    def batched_recipe(iterable, n):
        "Batch data into lists of length n. The last batch may be shorter."
        # batched('ABCDEFG', 3) --> ABC DEF G
        if n < 1:
            raise ValueError('n must be at least one')
        iterator = iter(iterable)
        while chunk := list(islice(iterator, n)):
            yield chunk

    inputs = ['', 'a', 'ab', 'abc', 'abcd', 'abcde', 'abcdef', 'abcdefg', None]
    sizes = [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, None]
    for iterable, n in product(inputs, sizes):
        with self.subTest(iterable=iterable, n=n):
            def outcome(func):
                # Return (exception type or None, result or None).
                try:
                    return None, list(func(iterable, n))
                except Exception as exc:
                    return type(exc), None
            e1, r1 = outcome(batched)
            e2, r2 = outcome(batched_recipe)
            self.assertEqual(r1, r2)
            self.assertEqual(e1, e2)
@staticmethod
def islice(iterable, *args):
s = slice(*args)
@ -1788,6 +1851,10 @@ class TestGC(unittest.TestCase):
a = []
self.makecycle(accumulate([1,2,a,3]), a)
def test_batched(self):
    # A reference cycle through the stored input must be collectable.
    cell = []
    self.makecycle(batched([1, 2, cell, 3], 2), cell)
def test_chain(self):
a = []
self.makecycle(chain(a), a)
@ -1972,6 +2039,18 @@ class TestVariousIteratorArgs(unittest.TestCase):
self.assertRaises(TypeError, accumulate, N(s))
self.assertRaises(ZeroDivisionError, list, accumulate(E(s)))
def test_batched(self):
    # batched() must accept any iterable and propagate iterator errors.
    src = 'abcde'
    expected = [['a', 'b'], ['c', 'd'], ['e']]
    for g in (G, I, Ig, L, R):
        with self.subTest(g=g):
            self.assertEqual(list(batched(g(src), 2)), expected)
    self.assertEqual(list(batched(S(src), 2)), [])
    self.assertRaises(TypeError, batched, X(src), 2)
    self.assertRaises(TypeError, batched, N(src), 2)
    self.assertRaises(ZeroDivisionError, list, batched(E(src), 2))
def test_chain(self):
for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)):
for g in (G, I, Ig, S, L, R):

View File

@ -0,0 +1,2 @@
Added itertools.batched() to batch data into lists of a given length with
the last list possibly being shorter than the others.

View File

@ -8,6 +8,85 @@ preserve
#endif
/* Docstring for itertools.batched().  Argument Clinic generated code:
   regenerate from the [clinic input] block in itertoolsmodule.c rather
   than editing this by hand. */
PyDoc_STRVAR(batched_new__doc__,
"batched(iterable, n)\n"
"--\n"
"\n"
"Batch data into lists of length n. The last batch may be shorter than n.\n"
"\n"
"Loops over the input iterable and accumulates data into lists\n"
"up to size n. The input is consumed lazily, just enough to\n"
"fill a list. The result is yielded as soon as a batch is full\n"
"or when the input iterable is exhausted.\n"
"\n"
" >>> for batch in batched(\'ABCDEFG\', 3):\n"
" ... print(batch)\n"
" ...\n"
" [\'A\', \'B\', \'C\']\n"
" [\'D\', \'E\', \'F\']\n"
" [\'G\']");
/* Argument Clinic generated keyword-parsing wrapper for
   itertools.batched.__new__.  Do not edit by hand; regenerate with clinic. */
static PyObject *
batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n);

static PyObject *
batched_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
{
    PyObject *return_value = NULL;

#if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE)
    /* Core builds pre-intern the keyword names in a static tuple so that
       argument parsing avoids per-call allocations. */
    #define NUM_KEYWORDS 2
    static struct {
        PyGC_Head _this_is_not_used;
        PyObject_VAR_HEAD
        PyObject *ob_item[NUM_KEYWORDS];
    } _kwtuple = {
        .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS)
        .ob_item = { &_Py_ID(iterable), &_Py_ID(n), },
    };
    #undef NUM_KEYWORDS
    #define KWTUPLE (&_kwtuple.ob_base.ob_base)
#else // !Py_BUILD_CORE
# define KWTUPLE NULL
#endif // !Py_BUILD_CORE

    static const char * const _keywords[] = {"iterable", "n", NULL};
    static _PyArg_Parser _parser = {
        .keywords = _keywords,
        .fname = "batched",
        .kwtuple = KWTUPLE,
    };
    #undef KWTUPLE
    PyObject *argsbuf[2];
    PyObject * const *fastargs;
    Py_ssize_t nargs = PyTuple_GET_SIZE(args);
    PyObject *iterable;
    Py_ssize_t n;

    /* Exactly two arguments (positional or keyword) are required. */
    fastargs = _PyArg_UnpackKeywords(_PyTuple_CAST(args)->ob_item, nargs, kwargs, NULL, &_parser, 2, 2, 0, argsbuf);
    if (!fastargs) {
        goto exit;
    }
    iterable = fastargs[0];
    /* Convert n via __index__ so any integer-like object is accepted. */
    {
        Py_ssize_t ival = -1;
        PyObject *iobj = _PyNumber_Index(fastargs[1]);
        if (iobj != NULL) {
            ival = PyLong_AsSsize_t(iobj);
            Py_DECREF(iobj);
        }
        if (ival == -1 && PyErr_Occurred()) {
            goto exit;
        }
        n = ival;
    }
    return_value = batched_new_impl(type, iterable, n);

exit:
    return return_value;
}
PyDoc_STRVAR(pairwise_new__doc__,
"pairwise(iterable, /)\n"
"--\n"
@ -834,4 +913,4 @@ skip_optional_pos:
exit:
return return_value;
}
/*[clinic end generated code: output=b1056d63f68a9059 input=a9049054013a1b77]*/
/*[clinic end generated code: output=efea8cd1e647bd17 input=a9049054013a1b77]*/

View File

@ -16,6 +16,7 @@ class itertools.groupby "groupbyobject *" "&groupby_type"
class itertools._grouper "_grouperobject *" "&_grouper_type"
class itertools.teedataobject "teedataobject *" "&teedataobject_type"
class itertools._tee "teeobject *" "&tee_type"
class itertools.batched "batchedobject *" "&batched_type"
class itertools.cycle "cycleobject *" "&cycle_type"
class itertools.dropwhile "dropwhileobject *" "&dropwhile_type"
class itertools.takewhile "takewhileobject *" "&takewhile_type"
@ -30,12 +31,13 @@ class itertools.filterfalse "filterfalseobject *" "&filterfalse_type"
class itertools.count "countobject *" "&count_type"
class itertools.pairwise "pairwiseobject *" "&pairwise_type"
[clinic start generated code]*/
/*[clinic end generated code: output=da39a3ee5e6b4b0d input=6498ed21fbe1bf94]*/
/*[clinic end generated code: output=da39a3ee5e6b4b0d input=1168b274011ce21b]*/
static PyTypeObject groupby_type;
static PyTypeObject _grouper_type;
static PyTypeObject teedataobject_type;
static PyTypeObject tee_type;
static PyTypeObject batched_type;
static PyTypeObject cycle_type;
static PyTypeObject dropwhile_type;
static PyTypeObject takewhile_type;
@ -51,6 +53,171 @@ static PyTypeObject pairwise_type;
#include "clinic/itertoolsmodule.c.h"
/* batched object ************************************************************/
/* Note: The built-in zip() function includes a "strict" argument
   that is needed because that function can silently truncate data
   and there is no easy way for a user to detect that condition.
   The same reasoning does not apply to batched() which never drops
   data.  Instead, it produces a shorter list which can be handled
   as the user sees fit.
*/
/* Instance state for itertools.batched. */
typedef struct {
    PyObject_HEAD
    PyObject *it;            /* owned iterator over the input; set to NULL
                                once the input is exhausted */
    Py_ssize_t batch_size;   /* n; >= 1, enforced in batched_new_impl() */
} batchedobject;
/*[clinic input]
@classmethod
itertools.batched.__new__ as batched_new
iterable: object
n: Py_ssize_t
Batch data into lists of length n. The last batch may be shorter than n.
Loops over the input iterable and accumulates data into lists
up to size n. The input is consumed lazily, just enough to
fill a list. The result is yielded as soon as a batch is full
or when the input iterable is exhausted.
>>> for batch in batched('ABCDEFG', 3):
... print(batch)
...
['A', 'B', 'C']
['D', 'E', 'F']
['G']
[clinic start generated code]*/
static PyObject *
batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n)
/*[clinic end generated code: output=7ebc954d655371b6 input=f28fd12cb52365f0]*/
{
    batchedobject *bo;
    PyObject *it;

    /* Reject n < 1 up front.  We could define the n == 0 case to return
       an empty iterator, but that would be at odds with the idea that
       batching should never throw away input data. */
    if (n < 1) {
        PyErr_SetString(PyExc_ValueError, "n must be at least one");
        return NULL;
    }

    it = PyObject_GetIter(iterable);
    if (it == NULL) {
        return NULL;
    }

    /* Allocate the instance; on failure, release the iterator we own. */
    bo = (batchedobject *)type->tp_alloc(type, 0);
    if (bo == NULL) {
        Py_DECREF(it);
        return NULL;
    }
    bo->it = it;
    bo->batch_size = n;
    return (PyObject *)bo;
}
/* Deallocate a batched object: untrack from the GC first so the collector
   never sees a half-torn-down object, then drop the source iterator. */
static void
batched_dealloc(batchedobject *bo)
{
    PyObject_GC_UnTrack(bo);
    Py_XDECREF(bo->it);   /* it may already be NULL after exhaustion */
    Py_TYPE(bo)->tp_free(bo);
}
/* GC traversal: visit the only PyObject reference held by the iterator.
   Py_VISIT() is documented to be a no-op when its argument is NULL, so
   the explicit it != NULL guard is unnecessary. */
static int
batched_traverse(batchedobject *bo, visitproc visit, void *arg)
{
    Py_VISIT(bo->it);
    return 0;
}
/* tp_iternext: return the next batch as a new list of up to batch_size
   items, or NULL when the input is exhausted (no partial data is ever
   discarded -- the final batch may simply be shorter). */
static PyObject *
batched_next(batchedobject *bo)
{
    Py_ssize_t i;
    PyObject *it = bo->it;
    PyObject *item;
    PyObject *result;

    /* it is cleared once exhausted; report our own exhaustion from then on. */
    if (it == NULL) {
        return NULL;
    }
    result = PyList_New(0);
    if (result == NULL) {
        return NULL;
    }
    for (i=0 ; i < bo->batch_size ; i++) {
        item = PyIter_Next(it);
        if (item == NULL) {
            break;
        }
        if (PyList_Append(result, item) < 0) {
            Py_DECREF(item);
            Py_DECREF(result);
            return NULL;
        }
        Py_DECREF(item);
    }
    /* PyIter_Next() returns NULL both on exhaustion and on error.  If an
       error is pending, propagate it rather than returning a partial batch
       with an exception set (which would violate the iterator protocol). */
    if (PyErr_Occurred()) {
        Py_CLEAR(bo->it);
        Py_DECREF(result);
        return NULL;
    }
    if (PyList_GET_SIZE(result) > 0) {
        return result;
    }
    /* Input exhausted with nothing buffered: drop the iterator and stop. */
    Py_CLEAR(bo->it);
    Py_DECREF(result);
    return NULL;
}
/* Type object for itertools.batched: a GC-tracked, subclassable iterator.
   tp_iter is PyObject_SelfIter (the object is its own iterator) and
   tp_iternext produces one batch per call. */
static PyTypeObject batched_type = {
    PyVarObject_HEAD_INIT(&PyType_Type, 0)
    "itertools.batched",            /* tp_name */
    sizeof(batchedobject),          /* tp_basicsize */
    0,                              /* tp_itemsize */
    /* methods */
    (destructor)batched_dealloc,    /* tp_dealloc */
    0,                              /* tp_vectorcall_offset */
    0,                              /* tp_getattr */
    0,                              /* tp_setattr */
    0,                              /* tp_as_async */
    0,                              /* tp_repr */
    0,                              /* tp_as_number */
    0,                              /* tp_as_sequence */
    0,                              /* tp_as_mapping */
    0,                              /* tp_hash */
    0,                              /* tp_call */
    0,                              /* tp_str */
    PyObject_GenericGetAttr,        /* tp_getattro */
    0,                              /* tp_setattro */
    0,                              /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
        Py_TPFLAGS_BASETYPE,        /* tp_flags */
    batched_new__doc__,             /* tp_doc */
    (traverseproc)batched_traverse, /* tp_traverse */
    0,                              /* tp_clear */
    0,                              /* tp_richcompare */
    0,                              /* tp_weaklistoffset */
    PyObject_SelfIter,              /* tp_iter */
    (iternextfunc)batched_next,     /* tp_iternext */
    0,                              /* tp_methods */
    0,                              /* tp_members */
    0,                              /* tp_getset */
    0,                              /* tp_base */
    0,                              /* tp_dict */
    0,                              /* tp_descr_get */
    0,                              /* tp_descr_set */
    0,                              /* tp_dictoffset */
    0,                              /* tp_init */
    PyType_GenericAlloc,            /* tp_alloc */
    batched_new,                    /* tp_new */
    PyObject_GC_Del,                /* tp_free */
};
/* pairwise object ***********************************************************/
typedef struct {
@ -4815,6 +4982,7 @@ repeat(elem [,n]) --> elem, elem, elem, ... endlessly or up to n times\n\
\n\
Iterators terminating on the shortest input sequence:\n\
accumulate(p[, func]) --> p0, p0+p1, p0+p1+p2\n\
batched(p, n) --> [p0, p1, ..., p_n-1], [p_n, p_n+1, ..., p_2n-1], ...\n\
chain(p, q, ...) --> p0, p1, ... plast, q0, q1, ...\n\
chain.from_iterable([p, q, ...]) --> p0, p1, ... plast, q0, q1, ...\n\
compress(data, selectors) --> (d[0] if s[0]), (d[1] if s[1]), ...\n\
@ -4841,6 +5009,7 @@ itertoolsmodule_exec(PyObject *m)
{
PyTypeObject *typelist[] = {
&accumulate_type,
&batched_type,
&combinations_type,
&cwr_type,
&cycle_type,