bpo-40636: PEP 618: add strict parameter to zip() (GH-20921)

zip() now supports PEP 618's strict parameter, which raises a
ValueError if the arguments are exhausted at different lengths.
Patch by Brandt Bucher.

Co-authored-by: Brandt Bucher <brandtbucher@gmail.com>
Co-authored-by: Ram Rachum <ram@rachum.com>
This commit is contained in:
Guido van Rossum 2020-06-19 03:16:57 -07:00 committed by GitHub
parent 37bb289556
commit 310f6aa7db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 238 additions and 8 deletions

View File

@ -1521,6 +1521,14 @@ class BuiltinTest(unittest.TestCase):
self.assertRaises(TypeError, vars, 42) self.assertRaises(TypeError, vars, 42)
self.assertEqual(vars(self.C_get_vars()), {'a':2}) self.assertEqual(vars(self.C_get_vars()), {'a':2})
def iter_error(self, iterable, error):
"""Collect `iterable` into a list, catching an expected `error`."""
items = []
with self.assertRaises(error):
for item in iterable:
items.append(item)
return items
def test_zip(self): def test_zip(self):
a = (1, 2, 3) a = (1, 2, 3)
b = (4, 5, 6) b = (4, 5, 6)
@ -1573,6 +1581,66 @@ class BuiltinTest(unittest.TestCase):
z1 = zip(a, b) z1 = zip(a, b)
self.check_iter_pickle(z1, t, proto) self.check_iter_pickle(z1, t, proto)
def test_zip_pickle_strict(self):
a = (1, 2, 3)
b = (4, 5, 6)
t = [(1, 4), (2, 5), (3, 6)]
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
z1 = zip(a, b, strict=True)
self.check_iter_pickle(z1, t, proto)
def test_zip_pickle_strict_fail(self):
a = (1, 2, 3)
b = (4, 5, 6, 7)
t = [(1, 4), (2, 5), (3, 6)]
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
z1 = zip(a, b, strict=True)
z2 = pickle.loads(pickle.dumps(z1, proto))
self.assertEqual(self.iter_error(z1, ValueError), t)
self.assertEqual(self.iter_error(z2, ValueError), t)
def test_zip_pickle_stability(self):
# Pickles of zip((1, 2, 3), (4, 5, 6)) dumped from 3.9:
pickles = [
b'citertools\nizip\np0\n(c__builtin__\niter\np1\n((I1\nI2\nI3\ntp2\ntp3\nRp4\nI0\nbg1\n((I4\nI5\nI6\ntp5\ntp6\nRp7\nI0\nbtp8\nRp9\n.',
b'citertools\nizip\nq\x00(c__builtin__\niter\nq\x01((K\x01K\x02K\x03tq\x02tq\x03Rq\x04K\x00bh\x01((K\x04K\x05K\x06tq\x05tq\x06Rq\x07K\x00btq\x08Rq\t.',
b'\x80\x02citertools\nizip\nq\x00c__builtin__\niter\nq\x01K\x01K\x02K\x03\x87q\x02\x85q\x03Rq\x04K\x00bh\x01K\x04K\x05K\x06\x87q\x05\x85q\x06Rq\x07K\x00b\x86q\x08Rq\t.',
b'\x80\x03cbuiltins\nzip\nq\x00cbuiltins\niter\nq\x01K\x01K\x02K\x03\x87q\x02\x85q\x03Rq\x04K\x00bh\x01K\x04K\x05K\x06\x87q\x05\x85q\x06Rq\x07K\x00b\x86q\x08Rq\t.',
b'\x80\x04\x95L\x00\x00\x00\x00\x00\x00\x00\x8c\x08builtins\x94\x8c\x03zip\x94\x93\x94\x8c\x08builtins\x94\x8c\x04iter\x94\x93\x94K\x01K\x02K\x03\x87\x94\x85\x94R\x94K\x00bh\x05K\x04K\x05K\x06\x87\x94\x85\x94R\x94K\x00b\x86\x94R\x94.',
b'\x80\x05\x95L\x00\x00\x00\x00\x00\x00\x00\x8c\x08builtins\x94\x8c\x03zip\x94\x93\x94\x8c\x08builtins\x94\x8c\x04iter\x94\x93\x94K\x01K\x02K\x03\x87\x94\x85\x94R\x94K\x00bh\x05K\x04K\x05K\x06\x87\x94\x85\x94R\x94K\x00b\x86\x94R\x94.',
]
for protocol, dump in enumerate(pickles):
z1 = zip((1, 2, 3), (4, 5, 6))
z2 = zip((1, 2, 3), (4, 5, 6), strict=False)
z3 = pickle.loads(dump)
l3 = list(z3)
self.assertEqual(type(z3), zip)
self.assertEqual(pickle.dumps(z1, protocol), dump)
self.assertEqual(pickle.dumps(z2, protocol), dump)
self.assertEqual(list(z1), l3)
self.assertEqual(list(z2), l3)
def test_zip_pickle_strict_stability(self):
# Pickles of zip((1, 2, 3), (4, 5), strict=True) dumped from 3.10:
pickles = [
b'citertools\nizip\np0\n(c__builtin__\niter\np1\n((I1\nI2\nI3\ntp2\ntp3\nRp4\nI0\nbg1\n((I4\nI5\ntp5\ntp6\nRp7\nI0\nbtp8\nRp9\nI01\nb.',
b'citertools\nizip\nq\x00(c__builtin__\niter\nq\x01((K\x01K\x02K\x03tq\x02tq\x03Rq\x04K\x00bh\x01((K\x04K\x05tq\x05tq\x06Rq\x07K\x00btq\x08Rq\tI01\nb.',
b'\x80\x02citertools\nizip\nq\x00c__builtin__\niter\nq\x01K\x01K\x02K\x03\x87q\x02\x85q\x03Rq\x04K\x00bh\x01K\x04K\x05\x86q\x05\x85q\x06Rq\x07K\x00b\x86q\x08Rq\t\x88b.',
b'\x80\x03cbuiltins\nzip\nq\x00cbuiltins\niter\nq\x01K\x01K\x02K\x03\x87q\x02\x85q\x03Rq\x04K\x00bh\x01K\x04K\x05\x86q\x05\x85q\x06Rq\x07K\x00b\x86q\x08Rq\t\x88b.',
b'\x80\x04\x95L\x00\x00\x00\x00\x00\x00\x00\x8c\x08builtins\x94\x8c\x03zip\x94\x93\x94\x8c\x08builtins\x94\x8c\x04iter\x94\x93\x94K\x01K\x02K\x03\x87\x94\x85\x94R\x94K\x00bh\x05K\x04K\x05\x86\x94\x85\x94R\x94K\x00b\x86\x94R\x94\x88b.',
b'\x80\x05\x95L\x00\x00\x00\x00\x00\x00\x00\x8c\x08builtins\x94\x8c\x03zip\x94\x93\x94\x8c\x08builtins\x94\x8c\x04iter\x94\x93\x94K\x01K\x02K\x03\x87\x94\x85\x94R\x94K\x00bh\x05K\x04K\x05\x86\x94\x85\x94R\x94K\x00b\x86\x94R\x94\x88b.',
]
a = (1, 2, 3)
b = (4, 5)
t = [(1, 4), (2, 5)]
for protocol, dump in enumerate(pickles):
z1 = zip(a, b, strict=True)
z2 = pickle.loads(dump)
self.assertEqual(pickle.dumps(z1, protocol), dump)
self.assertEqual(type(z2), zip)
self.assertEqual(self.iter_error(z1, ValueError), t)
self.assertEqual(self.iter_error(z2, ValueError), t)
def test_zip_bad_iterable(self): def test_zip_bad_iterable(self):
exception = TypeError() exception = TypeError()
@ -1585,6 +1653,88 @@ class BuiltinTest(unittest.TestCase):
self.assertIs(cm.exception, exception) self.assertIs(cm.exception, exception)
def test_zip_strict(self):
self.assertEqual(tuple(zip((1, 2, 3), 'abc', strict=True)),
((1, 'a'), (2, 'b'), (3, 'c')))
self.assertRaises(ValueError, tuple,
zip((1, 2, 3, 4), 'abc', strict=True))
self.assertRaises(ValueError, tuple,
zip((1, 2), 'abc', strict=True))
self.assertRaises(ValueError, tuple,
zip((1, 2), (1, 2), 'abc', strict=True))
def test_zip_strict_iterators(self):
x = iter(range(5))
y = [0]
z = iter(range(5))
self.assertRaises(ValueError, list,
(zip(x, y, z, strict=True)))
self.assertEqual(next(x), 2)
self.assertEqual(next(z), 1)
def test_zip_strict_error_handling(self):
class Error(Exception):
pass
class Iter:
def __init__(self, size):
self.size = size
def __iter__(self):
return self
def __next__(self):
self.size -= 1
if self.size < 0:
raise Error
return self.size
l1 = self.iter_error(zip("AB", Iter(1), strict=True), Error)
self.assertEqual(l1, [("A", 0)])
l2 = self.iter_error(zip("AB", Iter(2), "A", strict=True), ValueError)
self.assertEqual(l2, [("A", 1, "A")])
l3 = self.iter_error(zip("AB", Iter(2), "ABC", strict=True), Error)
self.assertEqual(l3, [("A", 1, "A"), ("B", 0, "B")])
l4 = self.iter_error(zip("AB", Iter(3), strict=True), ValueError)
self.assertEqual(l4, [("A", 2), ("B", 1)])
l5 = self.iter_error(zip(Iter(1), "AB", strict=True), Error)
self.assertEqual(l5, [(0, "A")])
l6 = self.iter_error(zip(Iter(2), "A", strict=True), ValueError)
self.assertEqual(l6, [(1, "A")])
l7 = self.iter_error(zip(Iter(2), "ABC", strict=True), Error)
self.assertEqual(l7, [(1, "A"), (0, "B")])
l8 = self.iter_error(zip(Iter(3), "AB", strict=True), ValueError)
self.assertEqual(l8, [(2, "A"), (1, "B")])
def test_zip_strict_error_handling_stopiteration(self):
class Iter:
def __init__(self, size):
self.size = size
def __iter__(self):
return self
def __next__(self):
self.size -= 1
if self.size < 0:
raise StopIteration
return self.size
l1 = self.iter_error(zip("AB", Iter(1), strict=True), ValueError)
self.assertEqual(l1, [("A", 0)])
l2 = self.iter_error(zip("AB", Iter(2), "A", strict=True), ValueError)
self.assertEqual(l2, [("A", 1, "A")])
l3 = self.iter_error(zip("AB", Iter(2), "ABC", strict=True), ValueError)
self.assertEqual(l3, [("A", 1, "A"), ("B", 0, "B")])
l4 = self.iter_error(zip("AB", Iter(3), strict=True), ValueError)
self.assertEqual(l4, [("A", 2), ("B", 1)])
l5 = self.iter_error(zip(Iter(1), "AB", strict=True), ValueError)
self.assertEqual(l5, [(0, "A")])
l6 = self.iter_error(zip(Iter(2), "A", strict=True), ValueError)
self.assertEqual(l6, [(1, "A")])
l7 = self.iter_error(zip(Iter(2), "ABC", strict=True), ValueError)
self.assertEqual(l7, [(1, "A"), (0, "B")])
l8 = self.iter_error(zip(Iter(3), "AB", strict=True), ValueError)
self.assertEqual(l8, [(2, "A"), (1, "B")])
def test_format(self): def test_format(self):
# Test the basic machinery of the format() builtin. Don't test # Test the basic machinery of the format() builtin. Don't test
# the specifics of the various formatters # the specifics of the various formatters

View File

@ -0,0 +1,3 @@
:func:`zip` now supports :pep:`618`'s ``strict`` parameter, which raises a
:exc:`ValueError` if the arguments are exhausted at different lengths.
Patch by Brandt Bucher.

View File

@ -2517,9 +2517,10 @@ builtin_issubclass_impl(PyObject *module, PyObject *cls,
typedef struct { typedef struct {
PyObject_HEAD PyObject_HEAD
Py_ssize_t tuplesize; Py_ssize_t tuplesize;
PyObject *ittuple; /* tuple of iterators */ PyObject *ittuple; /* tuple of iterators */
PyObject *result; PyObject *result;
int strict;
} zipobject; } zipobject;
static PyObject * static PyObject *
@ -2530,9 +2531,21 @@ zip_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
PyObject *ittuple; /* tuple of iterators */ PyObject *ittuple; /* tuple of iterators */
PyObject *result; PyObject *result;
Py_ssize_t tuplesize; Py_ssize_t tuplesize;
int strict = 0;
if (type == &PyZip_Type && !_PyArg_NoKeywords("zip", kwds)) if (kwds) {
return NULL; PyObject *empty = PyTuple_New(0);
if (empty == NULL) {
return NULL;
}
static char *kwlist[] = {"strict", NULL};
int parsed = PyArg_ParseTupleAndKeywords(
empty, kwds, "|$p:zip", kwlist, &strict);
Py_DECREF(empty);
if (!parsed) {
return NULL;
}
}
/* args must be a tuple */ /* args must be a tuple */
assert(PyTuple_Check(args)); assert(PyTuple_Check(args));
@ -2573,6 +2586,7 @@ zip_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
lz->ittuple = ittuple; lz->ittuple = ittuple;
lz->tuplesize = tuplesize; lz->tuplesize = tuplesize;
lz->result = result; lz->result = result;
lz->strict = strict;
return (PyObject *)lz; return (PyObject *)lz;
} }
@ -2613,6 +2627,9 @@ zip_next(zipobject *lz)
item = (*Py_TYPE(it)->tp_iternext)(it); item = (*Py_TYPE(it)->tp_iternext)(it);
if (item == NULL) { if (item == NULL) {
Py_DECREF(result); Py_DECREF(result);
if (lz->strict) {
goto check;
}
return NULL; return NULL;
} }
olditem = PyTuple_GET_ITEM(result, i); olditem = PyTuple_GET_ITEM(result, i);
@ -2628,28 +2645,85 @@ zip_next(zipobject *lz)
item = (*Py_TYPE(it)->tp_iternext)(it); item = (*Py_TYPE(it)->tp_iternext)(it);
if (item == NULL) { if (item == NULL) {
Py_DECREF(result); Py_DECREF(result);
if (lz->strict) {
goto check;
}
return NULL; return NULL;
} }
PyTuple_SET_ITEM(result, i, item); PyTuple_SET_ITEM(result, i, item);
} }
} }
return result; return result;
check:
if (PyErr_Occurred()) {
if (!PyErr_ExceptionMatches(PyExc_StopIteration)) {
// next() on argument i raised an exception (not StopIteration)
return NULL;
}
PyErr_Clear();
}
if (i) {
// ValueError: zip() argument 2 is shorter than argument 1
// ValueError: zip() argument 3 is shorter than arguments 1-2
const char* plural = i == 1 ? " " : "s 1-";
return PyErr_Format(PyExc_ValueError,
"zip() argument %d is shorter than argument%s%d",
i + 1, plural, i);
}
for (i = 1; i < tuplesize; i++) {
it = PyTuple_GET_ITEM(lz->ittuple, i);
item = (*Py_TYPE(it)->tp_iternext)(it);
if (item) {
Py_DECREF(item);
const char* plural = i == 1 ? " " : "s 1-";
return PyErr_Format(PyExc_ValueError,
"zip() argument %d is longer than argument%s%d",
i + 1, plural, i);
}
if (PyErr_Occurred()) {
if (!PyErr_ExceptionMatches(PyExc_StopIteration)) {
// next() on argument i raised an exception (not StopIteration)
return NULL;
}
PyErr_Clear();
}
// Argument i is exhausted. So far so good...
}
// All arguments are exhausted. Success!
return NULL;
} }
static PyObject * static PyObject *
zip_reduce(zipobject *lz, PyObject *Py_UNUSED(ignored)) zip_reduce(zipobject *lz, PyObject *Py_UNUSED(ignored))
{ {
/* Just recreate the zip with the internal iterator tuple */ /* Just recreate the zip with the internal iterator tuple */
return Py_BuildValue("OO", Py_TYPE(lz), lz->ittuple); if (lz->strict) {
return PyTuple_Pack(3, Py_TYPE(lz), lz->ittuple, Py_True);
}
return PyTuple_Pack(2, Py_TYPE(lz), lz->ittuple);
}
PyDoc_STRVAR(setstate_doc, "Set state information for unpickling.");
static PyObject *
zip_setstate(zipobject *lz, PyObject *state)
{
int strict = PyObject_IsTrue(state);
if (strict < 0) {
return NULL;
}
lz->strict = strict;
Py_RETURN_NONE;
} }
static PyMethodDef zip_methods[] = { static PyMethodDef zip_methods[] = {
{"__reduce__", (PyCFunction)zip_reduce, METH_NOARGS, reduce_doc}, {"__reduce__", (PyCFunction)zip_reduce, METH_NOARGS, reduce_doc},
{NULL, NULL} /* sentinel */ {"__setstate__", (PyCFunction)zip_setstate, METH_O, setstate_doc},
{NULL} /* sentinel */
}; };
PyDoc_STRVAR(zip_doc, PyDoc_STRVAR(zip_doc,
"zip(*iterables) --> A zip object yielding tuples until an input is exhausted.\n\ "zip(*iterables, strict=False) --> Yield tuples until an input is exhausted.\n\
\n\ \n\
>>> list(zip('abcdefg', range(3), range(4)))\n\ >>> list(zip('abcdefg', range(3), range(4)))\n\
[('a', 0, 0), ('b', 1, 1), ('c', 2, 2)]\n\ [('a', 0, 0), ('b', 1, 1), ('c', 2, 2)]\n\
@ -2657,7 +2731,10 @@ PyDoc_STRVAR(zip_doc,
The zip object yields n-length tuples, where n is the number of iterables\n\ The zip object yields n-length tuples, where n is the number of iterables\n\
passed as positional arguments to zip(). The i-th element in every tuple\n\ passed as positional arguments to zip(). The i-th element in every tuple\n\
comes from the i-th iterable argument to zip(). This continues until the\n\ comes from the i-th iterable argument to zip(). This continues until the\n\
shortest argument is exhausted."); shortest argument is exhausted.\n\
\n\
If strict is true and one of the arguments is exhausted before the others,\n\
raise a ValueError.");
PyTypeObject PyZip_Type = { PyTypeObject PyZip_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0) PyVarObject_HEAD_INIT(&PyType_Type, 0)