gh-123884 Tee of tee was not producing n independent iterators (gh-124490)

This commit is contained in:
Raymond Hettinger 2024-09-25 13:38:05 -07:00 committed by GitHub
parent fb6bd31cb7
commit 909c6f7189
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 89 additions and 87 deletions

View File

@ -691,25 +691,36 @@ loops that truncate the stream.
def tee(iterable, n=2): def tee(iterable, n=2):
if n < 0: if n < 0:
raise ValueError('n must be >= 0') raise ValueError
iterator = iter(iterable) if n == 0:
shared_link = [None, None] return ()
return tuple(_tee(iterator, shared_link) for _ in range(n)) iterator = _tee(iterable)
result = [iterator]
for _ in range(n - 1):
result.append(_tee(iterator))
return tuple(result)
def _tee(iterator, link): class _tee:
try:
while True:
if link[1] is None:
link[0] = next(iterator)
link[1] = [None, None]
value, link = link
yield value
except StopIteration:
return
Once a :func:`tee` has been created, the original *iterable* should not be def __init__(self, iterable):
used anywhere else; otherwise, the *iterable* could get advanced without it = iter(iterable)
the tee objects being informed. if isinstance(it, _tee):
self.iterator = it.iterator
self.link = it.link
else:
self.iterator = it
self.link = [None, None]
def __iter__(self):
return self
def __next__(self):
link = self.link
if link[1] is None:
link[0] = next(self.iterator)
link[1] = [None, None]
value, self.link = link
return value
When the input *iterable* is already a tee iterator object, all When the input *iterable* is already a tee iterator object, all
members of the return tuple are constructed as if they had been members of the return tuple are constructed as if they had been

View File

@ -604,7 +604,6 @@ _PyStaticObjects_CheckRefcnt(PyInterpreterState *interp) {
_PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(__classdictcell__)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(__classdictcell__));
_PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(__complex__)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(__complex__));
_PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(__contains__)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(__contains__));
_PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(__copy__));
_PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(__ctypes_from_outparam__)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(__ctypes_from_outparam__));
_PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(__del__)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(__del__));
_PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(__delattr__)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(__delattr__));

View File

@ -93,7 +93,6 @@ struct _Py_global_strings {
STRUCT_FOR_ID(__classdictcell__) STRUCT_FOR_ID(__classdictcell__)
STRUCT_FOR_ID(__complex__) STRUCT_FOR_ID(__complex__)
STRUCT_FOR_ID(__contains__) STRUCT_FOR_ID(__contains__)
STRUCT_FOR_ID(__copy__)
STRUCT_FOR_ID(__ctypes_from_outparam__) STRUCT_FOR_ID(__ctypes_from_outparam__)
STRUCT_FOR_ID(__del__) STRUCT_FOR_ID(__del__)
STRUCT_FOR_ID(__delattr__) STRUCT_FOR_ID(__delattr__)

View File

@ -602,7 +602,6 @@ extern "C" {
INIT_ID(__classdictcell__), \ INIT_ID(__classdictcell__), \
INIT_ID(__complex__), \ INIT_ID(__complex__), \
INIT_ID(__contains__), \ INIT_ID(__contains__), \
INIT_ID(__copy__), \
INIT_ID(__ctypes_from_outparam__), \ INIT_ID(__ctypes_from_outparam__), \
INIT_ID(__del__), \ INIT_ID(__del__), \
INIT_ID(__delattr__), \ INIT_ID(__delattr__), \

View File

@ -172,10 +172,6 @@ _PyUnicode_InitStaticStrings(PyInterpreterState *interp) {
_PyUnicode_InternStatic(interp, &string); _PyUnicode_InternStatic(interp, &string);
assert(_PyUnicode_CheckConsistency(string, 1)); assert(_PyUnicode_CheckConsistency(string, 1));
assert(PyUnicode_GET_LENGTH(string) != 1); assert(PyUnicode_GET_LENGTH(string) != 1);
string = &_Py_ID(__copy__);
_PyUnicode_InternStatic(interp, &string);
assert(_PyUnicode_CheckConsistency(string, 1));
assert(PyUnicode_GET_LENGTH(string) != 1);
string = &_Py_ID(__ctypes_from_outparam__); string = &_Py_ID(__ctypes_from_outparam__);
_PyUnicode_InternStatic(interp, &string); _PyUnicode_InternStatic(interp, &string);
assert(_PyUnicode_CheckConsistency(string, 1)); assert(_PyUnicode_CheckConsistency(string, 1));

View File

@ -1249,10 +1249,11 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(len(result), n) self.assertEqual(len(result), n)
self.assertEqual([list(x) for x in result], [list('abc')]*n) self.assertEqual([list(x) for x in result], [list('abc')]*n)
# tee pass-through to copyable iterator # tee objects are independent (see bug gh-123884)
a, b = tee('abc') a, b = tee('abc')
c, d = tee(a) c, d = tee(a)
self.assertTrue(a is c) e, f = tee(c)
self.assertTrue(len({a, b, c, d, e, f}) == 6)
# test tee_new # test tee_new
t1, t2 = tee('abc') t1, t2 = tee('abc')
@ -1759,21 +1760,36 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
def tee(iterable, n=2): def tee(iterable, n=2):
if n < 0: if n < 0:
raise ValueError('n must be >= 0') raise ValueError
iterator = iter(iterable) if n == 0:
shared_link = [None, None] return ()
return tuple(_tee(iterator, shared_link) for _ in range(n)) iterator = _tee(iterable)
result = [iterator]
for _ in range(n - 1):
result.append(_tee(iterator))
return tuple(result)
def _tee(iterator, link): class _tee:
try:
while True: def __init__(self, iterable):
if link[1] is None: it = iter(iterable)
link[0] = next(iterator) if isinstance(it, _tee):
link[1] = [None, None] self.iterator = it.iterator
value, link = link self.link = it.link
yield value else:
except StopIteration: self.iterator = it
return self.link = [None, None]
def __iter__(self):
return self
def __next__(self):
link = self.link
if link[1] is None:
link[0] = next(self.iterator)
link[1] = [None, None]
value, self.link = link
return value
# End tee() recipe ############################################# # End tee() recipe #############################################
@ -1819,12 +1835,10 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
self.assertRaises(TypeError, tee, [1,2], 'x') self.assertRaises(TypeError, tee, [1,2], 'x')
self.assertRaises(TypeError, tee, [1,2], 3, 'x') self.assertRaises(TypeError, tee, [1,2], 3, 'x')
# Tests not applicable to the tee() recipe # tee object should be instantiable
if False: a, b = tee('abc')
# tee object should be instantiable c = type(a)('def')
a, b = tee('abc') self.assertEqual(list(c), list('def'))
c = type(a)('def')
self.assertEqual(list(c), list('def'))
# test long-lagged and multi-way split # test long-lagged and multi-way split
a, b, c = tee(range(2000), 3) a, b, c = tee(range(2000), 3)
@ -1845,21 +1859,19 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
self.assertEqual(len(result), n) self.assertEqual(len(result), n)
self.assertEqual([list(x) for x in result], [list('abc')]*n) self.assertEqual([list(x) for x in result], [list('abc')]*n)
# tee objects are independent (see bug gh-123884)
a, b = tee('abc')
c, d = tee(a)
e, f = tee(c)
self.assertTrue(len({a, b, c, d, e, f}) == 6)
# Tests not applicable to the tee() recipe # test tee_new
if False: t1, t2 = tee('abc')
# tee pass-through to copyable iterator tnew = type(t1)
a, b = tee('abc') self.assertRaises(TypeError, tnew)
c, d = tee(a) self.assertRaises(TypeError, tnew, 10)
self.assertTrue(a is c) t3 = tnew(t1)
self.assertTrue(list(t1) == list(t2) == list(t3) == list('abc'))
# test tee_new
t1, t2 = tee('abc')
tnew = type(t1)
self.assertRaises(TypeError, tnew)
self.assertRaises(TypeError, tnew, 10)
t3 = tnew(t1)
self.assertTrue(list(t1) == list(t2) == list(t3) == list('abc'))
# test that tee objects are weak referencable # test that tee objects are weak referencable
a, b = tee(range(10)) a, b = tee(range(10))

View File

@ -0,0 +1,4 @@
Fixed bug in itertools.tee() handling of other tee inputs (a tee in a tee).
The output now has the promised *n* independent new iterators. Formerly,
the first iterator was identical (not independent) to the input iterator.
This would sometimes give surprising results.

View File

@ -1036,7 +1036,7 @@ itertools_tee_impl(PyObject *module, PyObject *iterable, Py_ssize_t n)
/*[clinic end generated code: output=1c64519cd859c2f0 input=c99a1472c425d66d]*/ /*[clinic end generated code: output=1c64519cd859c2f0 input=c99a1472c425d66d]*/
{ {
Py_ssize_t i; Py_ssize_t i;
PyObject *it, *copyable, *copyfunc, *result; PyObject *it, *to, *result;
if (n < 0) { if (n < 0) {
PyErr_SetString(PyExc_ValueError, "n must be >= 0"); PyErr_SetString(PyExc_ValueError, "n must be >= 0");
@ -1053,41 +1053,23 @@ itertools_tee_impl(PyObject *module, PyObject *iterable, Py_ssize_t n)
return NULL; return NULL;
} }
if (PyObject_GetOptionalAttr(it, &_Py_ID(__copy__), &copyfunc) < 0) { itertools_state *state = get_module_state(module);
Py_DECREF(it); to = tee_fromiterable(state, it);
Py_DECREF(it);
if (to == NULL) {
Py_DECREF(result); Py_DECREF(result);
return NULL; return NULL;
} }
if (copyfunc != NULL) {
copyable = it;
}
else {
itertools_state *state = get_module_state(module);
copyable = tee_fromiterable(state, it);
Py_DECREF(it);
if (copyable == NULL) {
Py_DECREF(result);
return NULL;
}
copyfunc = PyObject_GetAttr(copyable, &_Py_ID(__copy__));
if (copyfunc == NULL) {
Py_DECREF(copyable);
Py_DECREF(result);
return NULL;
}
}
PyTuple_SET_ITEM(result, 0, copyable); PyTuple_SET_ITEM(result, 0, to);
for (i = 1; i < n; i++) { for (i = 1; i < n; i++) {
copyable = _PyObject_CallNoArgs(copyfunc); to = tee_copy((teeobject *)to, NULL);
if (copyable == NULL) { if (to == NULL) {
Py_DECREF(copyfunc);
Py_DECREF(result); Py_DECREF(result);
return NULL; return NULL;
} }
PyTuple_SET_ITEM(result, i, copyable); PyTuple_SET_ITEM(result, i, to);
} }
Py_DECREF(copyfunc);
return result; return result;
} }