From 08ff6822cc64497a27aba9d84b3a51b706f01221 Mon Sep 17 00:00:00 2001 From: Raymond Hettinger Date: Fri, 29 Feb 2008 02:21:48 +0000 Subject: [PATCH] Handle the repeat keyword argument for itertools.product(). --- Lib/test/test_itertools.py | 3 +++ Modules/itertoolsmodule.c | 32 +++++++++++++++++++++++++++----- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 500afef01e5..087570c93f1 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -296,6 +296,9 @@ class TestBasicOps(unittest.TestCase): ([range(2), range(3), range(0)], []), # last iterable with zero length ]: self.assertEqual(list(product(*args)), result) + for r in range(4): + self.assertEqual(list(product(*(args*r))), + list(product(*args, **dict(repeat=r)))) self.assertEqual(len(list(product(*[range(7)]*6))), 7**6) self.assertRaises(TypeError, product, range(6), None) argtypes = ['', 'abc', '', xrange(0), xrange(4), dict(a=1, b=2, c=3), diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index f29077a3af1..e3d8bd8739f 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -1782,17 +1782,32 @@ static PyObject * product_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { productobject *lz; - Py_ssize_t npools; + Py_ssize_t nargs, npools, repeat=1; PyObject *pools = NULL; Py_ssize_t *maxvec = NULL; Py_ssize_t *indices = NULL; Py_ssize_t i; - if (type == &product_type && !_PyArg_NoKeywords("product()", kwds)) - return NULL; + if (kwds != NULL) { + char *kwlist[] = {"repeat", 0}; + PyObject *tmpargs = PyTuple_New(0); + if (tmpargs == NULL) + return NULL; + if (!PyArg_ParseTupleAndKeywords(tmpargs, kwds, "|n:product", kwlist, &repeat)) { + Py_DECREF(tmpargs); + return NULL; + } + Py_DECREF(tmpargs); + if (repeat < 0) { + PyErr_SetString(PyExc_ValueError, + "repeat argument cannot be negative"); + return NULL; + } + } assert(PyTuple_Check(args)); - npools = PyTuple_GET_SIZE(args); + nargs = (repeat == 0) ? 0 : PyTuple_GET_SIZE(args); + npools = nargs * repeat; maxvec = PyMem_Malloc(npools * sizeof(Py_ssize_t)); indices = PyMem_Malloc(npools * sizeof(Py_ssize_t)); @@ -1805,7 +1820,7 @@ product_new(PyTypeObject *type, PyObject *args, PyObject *kwds) if (pools == NULL) goto error; - for (i=0; i < npools; ++i) { + for (i=0; i < nargs ; ++i) { PyObject *item = PyTuple_GET_ITEM(args, i); PyObject *pool = PySequence_Tuple(item); if (pool == NULL) @@ -1815,6 +1830,13 @@ product_new(PyTypeObject *type, PyObject *args, PyObject *kwds) maxvec[i] = PyTuple_GET_SIZE(pool); indices[i] = 0; } + for ( ; i < npools; ++i) { + PyObject *pool = PyTuple_GET_ITEM(pools, i - nargs); + Py_INCREF(pool); + PyTuple_SET_ITEM(pools, i, pool); + maxvec[i] = maxvec[i - nargs]; + indices[i] = 0; + } /* create productobject structure */ lz = (productobject *)type->tp_alloc(type, 0);