Added itertools.tee()

It works like the pure python verion except:
* it stops storing data after of the iterators gets deallocated
* the data queue is implemented with two stacks instead of one dictionary.
This commit is contained in:
Raymond Hettinger 2003-10-24 08:45:23 +00:00
parent 16b9fa8db3
commit 6a5b027742
4 changed files with 446 additions and 93 deletions

View File

@ -108,9 +108,8 @@ by functions or loops that truncate the stream.
yield element
\end{verbatim}
Note, this is the only member of the toolkit that may require
significant auxiliary storage (depending on the length of the
iterable).
Note, this member of the toolkit may require significant
auxiliary storage (depending on the length of the iterable).
\end{funcdesc}
\begin{funcdesc}{dropwhile}{predicate, iterable}
@ -282,6 +281,32 @@ by functions or loops that truncate the stream.
\end{verbatim}
\end{funcdesc}
\begin{funcdesc}{tee}{iterable}
Return two independent iterators from a single iterable.
Equivalent to:
\begin{verbatim}
def tee(iterable):
def gen(next, data={}, cnt=[0]):
for i in count():
if i == cnt[0]:
item = data[i] = next()
cnt[0] += 1
else:
item = data.pop(i)
yield item
it = iter(iterable)
return (gen(it.next), gen(it.next))
\end{verbatim}
Note, this member of the toolkit may require significant auxiliary
storage (depending on how much temporary data needs to be stored).
In general, if one iterator is going use most or all of the data before
the other iterator, it is faster to use \function{list()} instead of
\function{tee()}.
\versionadded{2.4}
\end{funcdesc}
\subsection{Examples \label{itertools-example}}
@ -369,6 +394,17 @@ def ncycles(seq, n):
def dotproduct(vec1, vec2):
return sum(imap(operator.mul, vec1, vec2))
def flatten(listOfLists):
return list(chain(*listOfLists))
def repeatfunc(func, times=None, *args):
"Repeat calls to func with specified arguments."
"Example: repeatfunc(random.random)"
if times is None:
return starmap(func, repeat(args))
else:
return starmap(func, repeat(args, times))
def window(seq, n=2):
"Returns a sliding window (of width n) over data from the iterable"
" s -> (s0,s1,...s[n-1]), (s1,s2,...,sn), ... "
@ -380,18 +416,4 @@ def window(seq, n=2):
result = result[1:] + (elem,)
yield result
def tee(iterable):
"Return two independent iterators from a single iterable"
def gen(next, data={}, cnt=[0]):
dpop = data.pop
for i in count():
if i == cnt[0]:
item = data[i] = next()
cnt[0] += 1
else:
item = dpop(i)
yield item
next = iter(iterable).next
return (gen(next), gen(next))
\end{verbatim}

View File

@ -3,6 +3,7 @@ from test import test_support
from itertools import *
import sys
import operator
import random
def onearg(x):
'Test function of one argument'
@ -198,6 +199,50 @@ class TestBasicOps(unittest.TestCase):
self.assertRaises(TypeError, dropwhile(10, [(4,5)]).next)
self.assertRaises(ValueError, dropwhile(errfunc, [(4,5)]).next)
def test_tee(self):
n = 100
def irange(n):
for i in xrange(n):
yield i
a, b = tee([]) # test empty iterator
self.assertEqual(list(a), [])
self.assertEqual(list(b), [])
a, b = tee(irange(n)) # test 100% interleaved
self.assertEqual(zip(a,b), zip(range(n),range(n)))
a, b = tee(irange(n)) # test 0% interleaved
self.assertEqual(list(a), range(n))
self.assertEqual(list(b), range(n))
a, b = tee(irange(n)) # test dealloc of leading iterator
self.assertEqual(a.next(), 0)
self.assertEqual(a.next(), 1)
del a
self.assertEqual(list(b), range(n))
a, b = tee(irange(n)) # test dealloc of trailing iterator
self.assertEqual(a.next(), 0)
self.assertEqual(a.next(), 1)
del b
self.assertEqual(list(a), range(2, n))
for j in xrange(5): # test randomly interleaved
order = [0]*n + [1]*n
random.shuffle(order)
lists = ([], [])
its = tee(irange(n))
for i in order:
value = its[i].next()
lists[i].append(value)
self.assertEqual(lists[0], range(n))
self.assertEqual(lists[1], range(n))
self.assertRaises(TypeError, tee)
self.assertRaises(TypeError, tee, 3)
self.assertRaises(TypeError, tee, [1,2], 'x')
def test_StopIteration(self):
self.assertRaises(StopIteration, izip().next)
@ -208,12 +253,65 @@ class TestBasicOps(unittest.TestCase):
self.assertRaises(StopIteration, islice([], None).next)
self.assertRaises(StopIteration, islice(StopNow(), None).next)
p, q = tee([])
self.assertRaises(StopIteration, p.next)
self.assertRaises(StopIteration, q.next)
p, q = tee(StopNow())
self.assertRaises(StopIteration, p.next)
self.assertRaises(StopIteration, q.next)
self.assertRaises(StopIteration, repeat(None, 0).next)
for f in (ifilter, ifilterfalse, imap, takewhile, dropwhile, starmap):
self.assertRaises(StopIteration, f(lambda x:x, []).next)
self.assertRaises(StopIteration, f(lambda x:x, StopNow()).next)
class TestGC(unittest.TestCase):
def makecycle(self, iterator, container):
container.append(iterator)
iterator.next()
del container, iterator
def test_chain(self):
a = []
self.makecycle(chain(a), a)
def test_cycle(self):
a = []
self.makecycle(cycle([a]*2), a)
def test_ifilter(self):
a = []
self.makecycle(ifilter(lambda x:True, [a]*2), a)
def test_ifilterfalse(self):
a = []
self.makecycle(ifilterfalse(lambda x:False, a), a)
def test_izip(self):
a = []
self.makecycle(izip([a]*2, [a]*3), a)
def test_imap(self):
a = []
self.makecycle(imap(lambda x:x, [a]*2), a)
def test_islice(self):
a = []
self.makecycle(islice([a]*2, None), a)
def test_starmap(self):
a = []
self.makecycle(starmap(lambda *t: t, [(a,a)]*2), a)
def test_tee(self):
a = []
p, q = t = tee([a]*2)
a += [a, p, q, t]
p.next()
del a, p, q, t
def R(seqn):
'Regular generator'
for i in seqn:
@ -290,45 +388,6 @@ def L(seqn):
'Test multiple tiers of iterators'
return chain(imap(lambda x:x, R(Ig(G(seqn)))))
class TestGC(unittest.TestCase):
def makecycle(self, iterator, container):
container.append(iterator)
iterator.next()
del container, iterator
def test_chain(self):
a = []
self.makecycle(chain(a), a)
def test_cycle(self):
a = []
self.makecycle(cycle([a]*2), a)
def test_ifilter(self):
a = []
self.makecycle(ifilter(lambda x:True, [a]*2), a)
def test_ifilterfalse(self):
a = []
self.makecycle(ifilterfalse(lambda x:False, a), a)
def test_izip(self):
a = []
self.makecycle(izip([a]*2, [a]*3), a)
def test_imap(self):
a = []
self.makecycle(imap(lambda x:x, [a]*2), a)
def test_islice(self):
a = []
self.makecycle(islice([a]*2, None), a)
def test_starmap(self):
a = []
self.makecycle(starmap(lambda *t: t, [(a,a)]*2), a)
class TestVariousIteratorArgs(unittest.TestCase):
@ -427,6 +486,16 @@ class TestVariousIteratorArgs(unittest.TestCase):
self.assertRaises(TypeError, list, dropwhile(isOdd, N(s)))
self.assertRaises(ZeroDivisionError, list, dropwhile(isOdd, E(s)))
def test_tee(self):
for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
for g in (G, I, Ig, S, L, R):
it1, it2 = tee(g(s))
self.assertEqual(list(it1), list(g(s)))
self.assertEqual(list(it2), list(g(s)))
self.assertRaises(TypeError, tee, X(s))
self.assertRaises(TypeError, list, tee(N(s))[0])
self.assertRaises(ZeroDivisionError, list, tee(E(s))[0])
class RegressionTests(unittest.TestCase):
def test_sf_793826(self):
@ -531,6 +600,17 @@ Samuele
>>> def dotproduct(vec1, vec2):
... return sum(imap(operator.mul, vec1, vec2))
>>> def flatten(listOfLists):
... return list(chain(*listOfLists))
>>> def repeatfunc(func, times=None, *args):
... "Repeat calls to func with specified arguments."
... " Example: repeatfunc(random.random)"
... if times is None:
... return starmap(func, repeat(args))
... else:
... return starmap(func, repeat(args, times))
>>> def window(seq, n=2):
... "Returns a sliding window (of width n) over data from the iterable"
... " s -> (s0,s1,...s[n-1]), (s1,s2,...,sn), ... "
@ -542,20 +622,6 @@ Samuele
... result = result[1:] + (elem,)
... yield result
>>> def tee(iterable):
... "Return two independent iterators from a single iterable"
... def gen(next, data={}, cnt=[0]):
... dpop = data.pop
... for i in count():
... if i == cnt[0]:
... item = data[i] = next()
... cnt[0] += 1
... else:
... item = dpop(i)
... yield item
... next = iter(iterable).next
... return (gen(next), gen(next))
This is not part of the examples but it tests to make sure the definitions
perform as purported.
@ -592,6 +658,17 @@ False
>>> quantify(xrange(99), lambda x: x%2==0)
50
>>> a = [[1, 2, 3], [4, 5, 6]]
>>> flatten(a)
[1, 2, 3, 4, 5, 6]
>>> list(repeatfunc(pow, 5, 2, 3))
[8, 8, 8, 8, 8]
>>> import random
>>> take(5, imap(int, repeatfunc(random.random)))
[0, 0, 0, 0, 0]
>>> list(window('abc'))
[('a', 'b'), ('b', 'c')]
@ -607,14 +684,6 @@ False
>>> dotproduct([1,2,3], [4,5,6])
32
>>> x, y = tee(chain(xrange(2,10)))
>>> list(x), list(y)
([2, 3, 4, 5, 6, 7, 8, 9], [2, 3, 4, 5, 6, 7, 8, 9])
>>> x, y = tee(chain(xrange(2,10)))
>>> zip(x, y)
[(2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (7, 7), (8, 8), (9, 9)]
"""
__test__ = {'libreftest' : libreftest}

View File

@ -73,6 +73,24 @@ Extension modules
- Implemented (?(id/name)yes|no) support in SRE (#572936).
- random.seed() with no arguments or None uses time.time() as a default
seed. Modified to match Py2.2 behavior and use fractional seconds so
that successive runs are more likely to produce different sequences.
- random.Random has a new method, getrandbits(k), which returns an int
with k random bits. This method is now an optional part of the API
for user defined generators. Any generator that defines genrandbits()
can now use randrange() for ranges with a length >= 2**53. Formerly,
randrange would return only even numbers for ranges that large (see
SF bug #812202). Generators that do not define genrandbits() now
issue a warning when randrange() is called with a range that large.
- itertools now has a new function, tee() which produces two independent
iterators from a single iterable.
- itertools.izip() with no arguments now returns an empty iterator instead
of raising a TypeError exception.
Library
-------
@ -108,21 +126,6 @@ Library
allow any iterable. Also the Set.update() has been deprecated because
it duplicates Set.union_update().
- random.seed() with no arguments or None uses time.time() as a default
seed. Modified to match Py2.2 behavior and use fractional seconds so
that successive runs are more likely to produce different sequences.
- random.Random has a new method, getrandbits(k), which returns an int
with k random bits. This method is now an optional part of the API
for user defined generators. Any generator that defines genrandbits()
can now use randrange() for ranges with a length >= 2**53. Formerly,
randrange would return only even numbers for ranges that large (see
SF bug #812202). Generators that do not define genrandbits() now
issue a warning when randrange() is called with a range that large.
- itertools.izip() with no arguments now returns an empty iterator instead
of raising a TypeError exception.
- _strptime.py now has a behind-the-scenes caching mechanism for the most
recent TimeRE instance used along with the last five unique directive
patterns. The overall module was also made more thread-safe.

View File

@ -7,6 +7,264 @@
All rights reserved.
*/
/* independent iterator object supporting the tee object ***************/
/* The tee object maintains a queue of data seen by the leading iterator
but not seen by the trailing iterator. When the leading iterator
gets data from PyIter_Next() it appends a copy to the inbasket stack.
When the trailing iterator needs data, it is popped from the outbasket
stack. If the outbasket stack is empty, then it is filled from the
inbasket (i.e. the queue is implemented using two stacks so that only
O(n) operations like append() and pop() are used to access data and
calls to reverse() never move any data element more than once).
If one of the independent iterators gets deallocated, it sets tee's
save_mode to zero so that future calls to PyIter_Next() stop getting
saved to the queue (because there is no longer a second iterator that
may need the data).
*/
typedef struct {
PyObject_HEAD
PyObject *it;
PyObject *inbasket;
PyObject *outbasket;
int save_mode;
int num_seen;
} teeobject;
typedef struct {
PyObject_HEAD
teeobject *tee;
int num_seen;
} iiobject;
static PyTypeObject ii_type;
static PyObject *
ii_next(iiobject *lz)
{
teeobject *to = lz->tee;
PyObject *result, *tmp;
if (lz->num_seen == to->num_seen) {
/* This instance is leading, use iter to get more data */
result = PyIter_Next(to->it);
if (result == NULL)
return NULL;
if (to->save_mode)
PyList_Append(to->inbasket, result);
to->num_seen++;
lz->num_seen++;
return result;
}
/* This instance is trailing, get data from the queue */
if (PyList_GET_SIZE(to->outbasket) == 0) {
/* outbasket is empty, so refill from the inbasket */
tmp = to->outbasket;
to->outbasket = to->inbasket;
to->inbasket = tmp;
PyList_Reverse(to->outbasket);
assert(PyList_GET_SIZE(to->outbasket) > 0);
}
lz->num_seen++;
return PyObject_CallMethod(to->outbasket, "pop", NULL);
}
static void
ii_dealloc(iiobject *ii)
{
PyObject_GC_UnTrack(ii);
ii->tee->save_mode = 0; /* Stop saving data */
Py_XDECREF(ii->tee);
PyObject_GC_Del(ii);
}
static int
ii_traverse(iiobject *ii, visitproc visit, void *arg)
{
if (ii->tee)
return visit((PyObject *)(ii->tee), arg);
return 0;
}
PyDoc_STRVAR(ii_doc, "Independent iterators linked to a tee() object.");
static PyTypeObject ii_type = {
PyObject_HEAD_INIT(&PyType_Type)
0, /* ob_size */
"itertools.independent_iterator", /* tp_name */
sizeof(iiobject), /* tp_basicsize */
0, /* tp_itemsize */
/* methods */
(destructor)ii_dealloc, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_compare */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
PyObject_GenericGetAttr, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,/* tp_flags */
ii_doc, /* tp_doc */
(traverseproc)ii_traverse, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
PyObject_SelfIter, /* tp_iter */
(iternextfunc)ii_next, /* tp_iternext */
0, /* tp_methods */
};
/* tee object **********************************************************/
static PyTypeObject tee_type;
static PyObject *
tee_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
{
PyObject *it = NULL;
PyObject *iterable;
PyObject *inbasket = NULL, *outbasket = NULL, *result = NULL;
teeobject *to = NULL;
int i;
if (!PyArg_UnpackTuple(args, "tee", 1, 1, &iterable))
return NULL;
it = PyObject_GetIter(iterable);
if (it == NULL) goto fail;
inbasket = PyList_New(0);
if (inbasket == NULL) goto fail;
outbasket = PyList_New(0);
if (outbasket == NULL) goto fail;
to = (teeobject *)type->tp_alloc(type, 0);
if (to == NULL) goto fail;
to->it = it;
to->inbasket = inbasket;
to->outbasket = outbasket;
to->save_mode = 1;
to->num_seen = 0;
/* create independent iterators */
result = PyTuple_New(2);
if (result == NULL) goto fail;
for (i=0 ; i<2 ; i++) {
iiobject *indep_it = PyObject_GC_New(iiobject, &ii_type);
if (indep_it == NULL) goto fail;
Py_INCREF(to);
indep_it->tee = to;
indep_it->num_seen = 0;
PyObject_GC_Track(indep_it);
PyTuple_SET_ITEM(result, i, (PyObject *)indep_it);
}
goto succeed;
fail:
Py_XDECREF(it);
Py_XDECREF(inbasket);
Py_XDECREF(outbasket);
Py_XDECREF(result);
succeed:
Py_XDECREF(to);
return result;
}
static void
tee_dealloc(teeobject *to)
{
PyObject_GC_UnTrack(to);
Py_XDECREF(to->inbasket);
Py_XDECREF(to->outbasket);
Py_XDECREF(to->it);
to->ob_type->tp_free(to);
}
static int
tee_traverse(teeobject *to, visitproc visit, void *arg)
{
int err;
if (to->it) {
err = visit(to->it, arg);
if (err)
return err;
}
if (to->inbasket) {
err = visit(to->inbasket, arg);
if (err)
return err;
}
if (to->outbasket) {
err = visit(to->outbasket, arg);
if (err)
return err;
}
return 0;
}
PyDoc_STRVAR(tee_doc,
"tee(iterable) --> (it1, it2)\n\
\n\
Split the iterable into to independent iterables.");
static PyTypeObject tee_type = {
PyObject_HEAD_INIT(NULL)
0, /* ob_size */
"itertools.tee", /* tp_name */
sizeof(teeobject), /* tp_basicsize */
0, /* tp_itemsize */
/* methods */
(destructor)tee_dealloc, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_compare */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
PyObject_GenericGetAttr, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
Py_TPFLAGS_BASETYPE, /* tp_flags */
tee_doc, /* tp_doc */
(traverseproc)tee_traverse, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
0, /* tp_methods */
0, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
0, /* tp_init */
0, /* tp_alloc */
tee_new, /* tp_new */
PyObject_GC_Del, /* tp_free */
};
/* cycle object **********************************************************/
typedef struct {
@ -1824,6 +2082,7 @@ inititertools(void)
PyObject *m;
char *name;
PyTypeObject *typelist[] = {
&tee_type,
&cycle_type,
&dropwhile_type,
&takewhile_type,