Implement itertools.groupby()

Original idea by Guido van Rossum.
Idea for skippable inner iterators by Raymond Hettinger.
Idea for argument order and identity function default by Alex Martelli.
Implementation by Hye-Shik Chang (with tweaks by Raymond Hettinger).
Raymond Hettinger 2003-12-06 16:23:06 +00:00
parent b8d5f245b7
commit d25c1c6351
4 changed files with 493 additions and 2 deletions


@@ -130,6 +130,54 @@ by functions or loops that truncate the stream.
\end{verbatim}
\end{funcdesc}
\begin{funcdesc}{groupby}{iterable\optional{, key}}
Make an iterator that returns consecutive keys and groups from the
\var{iterable}. \var{key} is a function computing a key value for each
element. If not specified or \code{None}, \var{key} defaults to an
identity function (returning the element unchanged). Generally, the
iterable needs to already be sorted on the same key function.
The returned group is itself an iterator that shares the underlying
iterable with \function{groupby()}. Because the source is shared, when
the \function{groupby} object is advanced, the previous group is no
longer visible. So, if that data is needed later, it should be stored
as a list:
\begin{verbatim}
groups = []
uniquekeys = []
for k, g in groupby(data, keyfunc):
    groups.append(list(g))      # Store group iterator as a list
    uniquekeys.append(k)
\end{verbatim}
\function{groupby()} is equivalent to:
\begin{verbatim}
class groupby(object):
    def __init__(self, iterable, key=None):
        if key is None:
            key = lambda x: x
        self.keyfunc = key
        self.it = iter(iterable)
        self.tgtkey = self.currkey = self.currvalue = xrange(0)
    def __iter__(self):
        return self
    def next(self):
        while self.currkey == self.tgtkey:
            self.currvalue = self.it.next() # Exit on StopIteration
            self.currkey = self.keyfunc(self.currvalue)
        self.tgtkey = self.currkey
        return (self.currkey, self._grouper(self.tgtkey))
    def _grouper(self, tgtkey):
        while self.currkey == tgtkey:
            yield self.currvalue
            self.currvalue = self.it.next() # Exit on StopIteration
            self.currkey = self.keyfunc(self.currvalue)
\end{verbatim}
\versionadded{2.4}
\end{funcdesc}
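A short interactive sketch (not part of the committed documentation) illustrates the shared-iterator caveat described above; the sample string and variable names are invented, and the session uses the Python 2.4 \code{next()} method as in the rest of this document:
\begin{verbatim}
>>> from itertools import groupby
>>> gb = groupby('aaabbc')      # default key is the identity function
>>> k, g = gb.next()            # first key and its group iterator
>>> k
'a'
>>> gb.next()[0]                # advancing groupby() consumes the rest of group 'a'
'b'
>>> list(g)                     # so the earlier group iterator is now exhausted
[]
\end{verbatim}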
\begin{funcdesc}{ifilter}{predicate, iterable}
Make an iterator that filters elements from iterable returning only
those for which the predicate is \code{True}.
@@ -346,6 +394,18 @@ Martin
Walter
Samuele
# Show a dictionary sorted and grouped by value
>>> from operator import itemgetter
>>> d = dict(a=1, b=2, c=1, d=2, e=1, f=2, g=3)
>>> di = list.sorted(d.iteritems(), key=itemgetter(1))
>>> for k, g in groupby(di, key=itemgetter(1)):
...     print k, map(itemgetter(0), g)
...
1 ['a', 'c', 'e']
2 ['b', 'd', 'f']
3 ['g']
\end{verbatim}
This section shows how itertools can be combined to create other more This section shows how itertools can be combined to create other more


@@ -61,6 +61,94 @@ class TestBasicOps(unittest.TestCase):
        self.assertRaises(TypeError, cycle, 5)
        self.assertEqual(list(islice(cycle(gen3()),10)), [0,1,2,0,1,2,0,1,2,0])
    def test_groupby(self):
        # Check whether it accepts arguments correctly
        self.assertEqual([], list(groupby([])))
        self.assertEqual([], list(groupby([], key=id)))
        self.assertRaises(TypeError, list, groupby('abc', []))
        self.assertRaises(TypeError, groupby, None)
        # Check normal input
        s = [(0, 10, 20), (0, 11,21), (0,12,21), (1,13,21), (1,14,22),
             (2,15,22), (3,16,23), (3,17,23)]
        dup = []
        for k, g in groupby(s, lambda r:r[0]):
            for elem in g:
                self.assertEqual(k, elem[0])
                dup.append(elem)
        self.assertEqual(s, dup)
        # Check nested case
        dup = []
        for k, g in groupby(s, lambda r:r[0]):
            for ik, ig in groupby(g, lambda r:r[2]):
                for elem in ig:
                    self.assertEqual(k, elem[0])
                    self.assertEqual(ik, elem[2])
                    dup.append(elem)
        self.assertEqual(s, dup)
        # Check case where inner iterator is not used
        keys = [k for k, g in groupby(s, lambda r:r[0])]
        expectedkeys = set([r[0] for r in s])
        self.assertEqual(set(keys), expectedkeys)
        self.assertEqual(len(keys), len(expectedkeys))
        # Exercise pipes and filters style
        s = 'abracadabra'
        # sort s | uniq
        r = [k for k, g in groupby(list.sorted(s))]
        self.assertEqual(r, ['a', 'b', 'c', 'd', 'r'])
        # sort s | uniq -d
        r = [k for k, g in groupby(list.sorted(s)) if list(islice(g,1,2))]
        self.assertEqual(r, ['a', 'b', 'r'])
        # sort s | uniq -c
        r = [(len(list(g)), k) for k, g in groupby(list.sorted(s))]
        self.assertEqual(r, [(5, 'a'), (2, 'b'), (1, 'c'), (1, 'd'), (2, 'r')])
        # sort s | uniq -c | sort -rn | head -3
        r = list.sorted([(len(list(g)) , k) for k, g in groupby(list.sorted(s))], reverse=True)[:3]
        self.assertEqual(r, [(5, 'a'), (2, 'r'), (2, 'b')])
        # iter.next failure
        class ExpectedError(Exception):
            pass
        def delayed_raise(n=0):
            for i in range(n):
                yield 'yo'
            raise ExpectedError
        def gulp(iterable, keyp=None, func=list):
            return [func(g) for k, g in groupby(iterable, keyp)]
        # iter.next failure on outer object
        self.assertRaises(ExpectedError, gulp, delayed_raise(0))
        # iter.next failure on inner object
        self.assertRaises(ExpectedError, gulp, delayed_raise(1))
        # __cmp__ failure
        class DummyCmp:
            def __cmp__(self, dst):
                raise ExpectedError
        s = [DummyCmp(), DummyCmp(), None]
        # __cmp__ failure on outer object
        self.assertRaises(ExpectedError, gulp, s, func=id)
        # __cmp__ failure on inner object
        self.assertRaises(ExpectedError, gulp, s)
        # keyfunc failure
        def keyfunc(obj):
            if keyfunc.skip > 0:
                keyfunc.skip -= 1
                return obj
            else:
                raise ExpectedError
        # keyfunc failure on outer object
        keyfunc.skip = 0
        self.assertRaises(ExpectedError, gulp, [None], keyfunc)
        keyfunc.skip = 1
        self.assertRaises(ExpectedError, gulp, [None, None], keyfunc)
    def test_ifilter(self):
        self.assertEqual(list(ifilter(isEven, range(6))), [0,2,4])
        self.assertEqual(list(ifilter(None, [0,1,0,2,0])), [1,2])
@@ -268,7 +356,7 @@ class TestBasicOps(unittest.TestCase):
    def test_StopIteration(self):
        self.assertRaises(StopIteration, izip().next)
        for f in (chain, cycle, izip, groupby):
            self.assertRaises(StopIteration, f([]).next)
            self.assertRaises(StopIteration, f(StopNow()).next)
@@ -426,6 +514,14 @@ class TestVariousIteratorArgs(unittest.TestCase):
            self.assertRaises(TypeError, list, cycle(N(s)))
            self.assertRaises(ZeroDivisionError, list, cycle(E(s)))
    def test_groupby(self):
        for s in (range(10), range(0), range(1000), (7,11), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                self.assertEqual([k for k, sb in groupby(g(s))], list(g(s)))
            self.assertRaises(TypeError, groupby, X(s))
            self.assertRaises(TypeError, list, groupby(N(s)))
            self.assertRaises(ZeroDivisionError, list, groupby(E(s)))
    def test_ifilter(self):
        for s in (range(10), range(0), range(1000), (7,11), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
@@ -571,6 +667,16 @@ Martin
Walter
Samuele
>>> from operator import itemgetter
>>> d = dict(a=1, b=2, c=1, d=2, e=1, f=2, g=3)
>>> di = list.sorted(d.iteritems(), key=itemgetter(1))
>>> for k, g in groupby(di, itemgetter(1)):
...     print k, map(itemgetter(0), g)
...
1 ['a', 'c', 'e']
2 ['b', 'd', 'f']
3 ['g']
>>> def take(n, seq):
...     return list(islice(seq, n))


@@ -164,6 +164,11 @@ Extension modules
  SF bug #812202). Generators that do not define genrandbits() now
  issue a warning when randrange() is called with a range that large.
- itertools has a new function, groupby(), for aggregating iterables
  into groups sharing the same key (as determined by a key function).
  It offers some of the functionality of SQL's GROUP BY clause and of
  the Unix uniq filter.
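  A minimal sketch of the uniq-style behaviour mentioned in the entry
  above (illustrative only, not part of the NEWS file; the sample
  string is invented):

      >>> from itertools import groupby
      >>> [k for k, g in groupby('aaaabbbccd')]                  # collapse runs, like uniq
      ['a', 'b', 'c', 'd']
      >>> [(k, len(list(g))) for k, g in groupby('aaaabbbccd')]  # count runs, like uniq -c
      [('a', 4), ('b', 3), ('c', 2), ('d', 1)]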
- itertools now has a new function, tee() which produces two independent
  iterators from a single iterable.


@@ -7,6 +7,323 @@
All rights reserved.
*/
/* groupby object ***********************************************************/
typedef struct {
    PyObject_HEAD
    PyObject *it;
    PyObject *keyfunc;
    PyObject *tgtkey;
    PyObject *currkey;
    PyObject *currvalue;
} groupbyobject;
static PyTypeObject groupby_type;
static PyObject *_grouper_create(groupbyobject *, PyObject *);
static PyObject *
groupby_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
{
    static char *kwargs[] = {"iterable", "key", NULL};
    groupbyobject *gbo;
    PyObject *it, *keyfunc = Py_None;
    if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:groupby", kwargs,
                                     &it, &keyfunc))
        return NULL;
    gbo = (groupbyobject *)type->tp_alloc(type, 0);
    if (gbo == NULL)
        return NULL;
    gbo->tgtkey = NULL;
    gbo->currkey = NULL;
    gbo->currvalue = NULL;
    gbo->keyfunc = keyfunc;
    Py_INCREF(keyfunc);
    gbo->it = PyObject_GetIter(it);
    if (gbo->it == NULL) {
        Py_DECREF(gbo);
        return NULL;
    }
    return (PyObject *)gbo;
}
static void
groupby_dealloc(groupbyobject *gbo)
{
    PyObject_GC_UnTrack(gbo);
    Py_XDECREF(gbo->it);
    Py_XDECREF(gbo->keyfunc);
    Py_XDECREF(gbo->tgtkey);
    Py_XDECREF(gbo->currkey);
    Py_XDECREF(gbo->currvalue);
    gbo->ob_type->tp_free(gbo);
}
static int
groupby_traverse(groupbyobject *gbo, visitproc visit, void *arg)
{
    int err;
    if (gbo->it) {
        err = visit(gbo->it, arg);
        if (err)
            return err;
    }
    if (gbo->keyfunc) {
        err = visit(gbo->keyfunc, arg);
        if (err)
            return err;
    }
    if (gbo->tgtkey) {
        err = visit(gbo->tgtkey, arg);
        if (err)
            return err;
    }
    if (gbo->currkey) {
        err = visit(gbo->currkey, arg);
        if (err)
            return err;
    }
    if (gbo->currvalue) {
        err = visit(gbo->currvalue, arg);
        if (err)
            return err;
    }
    return 0;
}
static PyObject *
groupby_next(groupbyobject *gbo)
{
    PyObject *newvalue, *newkey, *r, *grouper;
    /* skip to next iteration group */
    for (;;) {
        if (gbo->currkey == NULL)
            /* pass */;
        else if (gbo->tgtkey == NULL)
            break;
        else {
            int rcmp;
            rcmp = PyObject_RichCompareBool(gbo->tgtkey,
                                            gbo->currkey, Py_EQ);
            if (rcmp == -1)
                return NULL;
            else if (rcmp == 0)
                break;
        }
        newvalue = PyIter_Next(gbo->it);
        if (newvalue == NULL)
            return NULL;
        if (gbo->keyfunc == Py_None) {
            newkey = newvalue;
            Py_INCREF(newvalue);
        } else {
            newkey = PyObject_CallFunctionObjArgs(gbo->keyfunc,
                                                  newvalue, NULL);
            if (newkey == NULL) {
                Py_DECREF(newvalue);
                return NULL;
            }
        }
        Py_XDECREF(gbo->currkey);
        gbo->currkey = newkey;
        Py_XDECREF(gbo->currvalue);
        gbo->currvalue = newvalue;
    }
    Py_XDECREF(gbo->tgtkey);
    gbo->tgtkey = gbo->currkey;
    Py_INCREF(gbo->currkey);
    grouper = _grouper_create(gbo, gbo->tgtkey);
    if (grouper == NULL)
        return NULL;
    r = PyTuple_Pack(2, gbo->currkey, grouper);
    Py_DECREF(grouper);
    return r;
}
PyDoc_STRVAR(groupby_doc,
"groupby(iterable[, keyfunc]) -> create an iterator which returns\n\
(key, sub-iterator) grouped by each value of key(value).\n");
static PyTypeObject groupby_type = {
PyObject_HEAD_INIT(NULL)
0, /* ob_size */
"itertools.groupby", /* tp_name */
sizeof(groupbyobject), /* tp_basicsize */
0, /* tp_itemsize */
/* methods */
(destructor)groupby_dealloc, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_compare */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
PyObject_GenericGetAttr, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
Py_TPFLAGS_BASETYPE, /* tp_flags */
groupby_doc, /* tp_doc */
(traverseproc)groupby_traverse, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
PyObject_SelfIter, /* tp_iter */
(iternextfunc)groupby_next, /* tp_iternext */
0, /* tp_methods */
0, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
0, /* tp_init */
0, /* tp_alloc */
groupby_new, /* tp_new */
PyObject_GC_Del, /* tp_free */
};
/* _grouper object (internal) ************************************************/
typedef struct {
    PyObject_HEAD
    PyObject *parent;
    PyObject *tgtkey;
} _grouperobject;
static PyTypeObject _grouper_type;
static PyObject *
_grouper_create(groupbyobject *parent, PyObject *tgtkey)
{
    _grouperobject *igo;
    igo = PyObject_New(_grouperobject, &_grouper_type);
    if (igo == NULL)
        return NULL;
    igo->parent = (PyObject *)parent;
    Py_INCREF(parent);
    igo->tgtkey = tgtkey;
    Py_INCREF(tgtkey);
    return (PyObject *)igo;
}
static void
_grouper_dealloc(_grouperobject *igo)
{
    Py_DECREF(igo->parent);
    Py_DECREF(igo->tgtkey);
    PyObject_Del(igo);
}
static PyObject *
_grouper_next(_grouperobject *igo)
{
    groupbyobject *gbo = (groupbyobject *)igo->parent;
    PyObject *newvalue, *newkey, *r;
    int rcmp;
    if (gbo->currvalue == NULL) {
        newvalue = PyIter_Next(gbo->it);
        if (newvalue == NULL)
            return NULL;
        if (gbo->keyfunc == Py_None) {
            newkey = newvalue;
            Py_INCREF(newvalue);
        } else {
            newkey = PyObject_CallFunctionObjArgs(gbo->keyfunc,
                                                  newvalue, NULL);
            if (newkey == NULL) {
                Py_DECREF(newvalue);
                return NULL;
            }
        }
        assert(gbo->currkey == NULL);
        gbo->currkey = newkey;
        gbo->currvalue = newvalue;
    }
    assert(gbo->currkey != NULL);
    rcmp = PyObject_RichCompareBool(igo->tgtkey, gbo->currkey, Py_EQ);
    if (rcmp <= 0)
        /* an error occurred, or the current group has ended */
        return NULL;
    r = gbo->currvalue;
    gbo->currvalue = NULL;
    Py_DECREF(gbo->currkey);
    gbo->currkey = NULL;
    return r;
}
static PyTypeObject _grouper_type = {
PyObject_HEAD_INIT(NULL)
0, /* ob_size */
"itertools._grouper", /* tp_name */
sizeof(_grouperobject), /* tp_basicsize */
0, /* tp_itemsize */
/* methods */
(destructor)_grouper_dealloc, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_compare */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
PyObject_GenericGetAttr, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
0, /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
PyObject_SelfIter, /* tp_iter */
(iternextfunc)_grouper_next, /* tp_iternext */
0, /* tp_methods */
0, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
0, /* tp_init */
0, /* tp_alloc */
0, /* tp_new */
PyObject_Del, /* tp_free */
};
/* tee object and with supporting function and objects ***************/
/* The teedataobject pre-allocates space for LINKCELLS number of objects.
@@ -2103,6 +2420,7 @@ tee(it, n=2) --> (it1, it2 , ... itn) splits one iterator into n\n\
chain(p, q, ...) --> p0, p1, ... plast, q0, q1, ... \n\
takewhile(pred, seq) --> seq[0], seq[1], until pred fails\n\
dropwhile(pred, seq) --> seq[n], seq[n+1], starting when pred fails\n\
groupby(iterable[, keyfunc]) --> sub-iterators grouped by value of keyfunc(v)\n\
");
@@ -2130,6 +2448,7 @@ inititertools(void)
&count_type,
&izip_type,
&repeat_type,
&groupby_type,
NULL
};
@@ -2148,5 +2467,6 @@ inititertools(void)
return;
if (PyType_Ready(&tee_type) < 0)
return;
if (PyType_Ready(&_grouper_type) < 0)
return;
}