diff --git a/Include/Python.h b/Include/Python.h index f440a3a8635..e3addc704a0 100644 --- a/Include/Python.h +++ b/Include/Python.h @@ -91,6 +91,7 @@ #include "tupleobject.h" #include "listobject.h" #include "dictobject.h" +#include "enumobject.h" #include "methodobject.h" #include "moduleobject.h" #include "funcobject.h" diff --git a/Include/enumobject.h b/Include/enumobject.h new file mode 100644 index 00000000000..df20fb0b2c6 --- /dev/null +++ b/Include/enumobject.h @@ -0,0 +1,16 @@ +#ifndef Py_ENUMOBJECT_H +#define Py_ENUMOBJECT_H + +/* Enumerate Object */ + +#ifdef __cplusplus +extern "C" { +#endif + +extern DL_IMPORT(PyTypeObject) PyEnum_Type; + +#ifdef __cplusplus +} +#endif + +#endif /* !Py_ENUMOBJECT_H */ diff --git a/Lib/test/test_enumerate.py b/Lib/test/test_enumerate.py new file mode 100644 index 00000000000..b0d442e0f90 --- /dev/null +++ b/Lib/test/test_enumerate.py @@ -0,0 +1,118 @@ +from __future__ import generators +import unittest + +import test_support + +seq, res = 'abc', [(0,'a'), (1,'b'), (2,'c')] + +class G: + 'Sequence using __getitem__' + def __init__(self, seqn): + self.seqn = seqn + def __getitem__(self, i): + return self.seqn[i] + +class I: + 'Sequence using iterator protocol' + def __init__(self, seqn): + self.seqn = seqn + self.i = 0 + def __iter__(self): + return self + def next(self): + if self.i >= len(self.seqn): raise StopIteration + v = self.seqn[self.i] + self.i += 1 + return v + +class Ig: + 'Sequence using iterator protocol defined with a generator' + def __init__(self, seqn): + self.seqn = seqn + self.i = 0 + def __iter__(self): + for val in self.seqn: + yield val + +class X: + 'Missing __getitem__ and __iter__' + def __init__(self, seqn): + self.seqn = seqn + self.i = 0 + def next(self): + if self.i >= len(self.seqn): raise StopIteration + v = self.seqn[self.i] + self.i += 1 + return v + +class E: + 'Test propagation of exceptions' + def __init__(self, seqn): + self.seqn = seqn + self.i = 0 + def __iter__(self): + return self + def next(self): + 3/0 + +class N: + 'Iterator missing next()' + def __init__(self, seqn): + self.seqn = seqn + self.i = 0 + def __iter__(self): + return self + +class EnumerateTestCase(unittest.TestCase): + + enum = enumerate + + def test_basicfunction(self): + self.assertEqual(type(self.enum(seq)), self.enum) + e = self.enum(seq) + self.assertEqual(iter(e), e) + self.assertEqual(list(self.enum(seq)), res) + self.enum.__doc__ + + def test_getitemseqn(self): + self.assertEqual(list(self.enum(G(seq))), res) + e = self.enum(G('')) + self.assertRaises(StopIteration, e.next) + + def test_iteratorseqn(self): + self.assertEqual(list(self.enum(I(seq))), res) + e = self.enum(I('')) + self.assertRaises(StopIteration, e.next) + + def test_iteratorgenerator(self): + self.assertEqual(list(self.enum(Ig(seq))), res) + e = self.enum(Ig('')) + self.assertRaises(StopIteration, e.next) + + def test_noniterable(self): + self.assertRaises(TypeError, self.enum, X(seq)) + + def test_illformediterable(self): + self.assertRaises(TypeError, list, self.enum(N(seq))) + + def test_exception_propagation(self): + self.assertRaises(ZeroDivisionError, list, self.enum(E(seq))) + +class MyEnum(enumerate): + pass + +class SubclassTestCase(EnumerateTestCase): + + enum = MyEnum + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(EnumerateTestCase)) + suite.addTest(unittest.makeSuite(SubclassTestCase)) + return suite + +def test_main(): + test_support.run_suite(suite()) + +if __name__ == "__main__": + test_main() diff --git a/Makefile.pre.in b/Makefile.pre.in index 9ad1c0a2d55..9ec4fa66974 100644 --- a/Makefile.pre.in +++ b/Makefile.pre.in @@ -257,6 +257,7 @@ OBJECT_OBJS= \ Objects/cobject.o \ Objects/complexobject.o \ Objects/descrobject.o \ + Objects/enumobject.o \ Objects/fileobject.o \ Objects/floatobject.o \ Objects/frameobject.o \ @@ -443,6 +444,7 @@ PYTHON_HEADERS= \ Include/complexobject.h \ Include/descrobject.h \ Include/dictobject.h \ + Include/enumobject.h \ Include/fileobject.h \ Include/floatobject.h \ Include/funcobject.h \ diff --git a/Misc/NEWS b/Misc/NEWS index c63db58bddb..d83d7b29f87 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -6,6 +6,10 @@ Type/class unification and new-style classes Core and builtins +- New builtin function enumerate(x), from PEP 279. Example: + enumerate("abc") is an iterator returning (0,"a"), (1,"b"), (2,"c"). + The argument can be an arbitrary iterable object. + - The assert statement no longer tests __debug__ at runtime. This means that assert statements cannot be disabled by assigning a false value to __debug__. diff --git a/Objects/enumobject.c b/Objects/enumobject.c new file mode 100644 index 00000000000..19649564155 --- /dev/null +++ b/Objects/enumobject.c @@ -0,0 +1,139 @@ +/* enumerate object */ + +#include "Python.h" + +typedef struct { + PyObject_HEAD + long en_index; /* current index of enumeration */ + PyObject* en_sit; /* secondary iterator of enumeration */ +} enumobject; + +PyTypeObject PyEnum_Type; + +static PyObject * +enum_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + enumobject *en; + PyObject *seq = NULL; + static char *kwlist[] = {"sequence", 0}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:enumerate", kwlist, + &seq)) + return NULL; + + en = (enumobject *)type->tp_alloc(type, 0); + if (en == NULL) + return NULL; + en->en_index = 0; + en->en_sit = PyObject_GetIter(seq); + if (en->en_sit == NULL) { + Py_DECREF(en); + return NULL; + } + return (PyObject *)en; +} + +static void +enum_dealloc(enumobject *en) +{ + PyObject_GC_UnTrack(en); + Py_XDECREF(en->en_sit); + en->ob_type->tp_free(en); +} + +static int +enum_traverse(enumobject *en, visitproc visit, void *arg) +{ + if (en->en_sit) + return visit(en->en_sit, arg); + return 0; +} + +static PyObject * +enum_next(enumobject *en) +{ + PyObject *result; + PyObject *next_index; + + PyObject *next_item = PyIter_Next(en->en_sit); + if (next_item == NULL) + return NULL; + + result = PyTuple_New(2); + if (result == NULL) { + Py_DECREF(next_item); + return NULL; + } + + next_index = PyInt_FromLong(en->en_index++); + if (next_index == NULL) { + Py_DECREF(next_item); + Py_DECREF(result); + return NULL; + } + + PyTuple_SET_ITEM(result, 0, next_index); + PyTuple_SET_ITEM(result, 1, next_item); + return result; +} + +static PyObject * +enum_getiter(PyObject *en) +{ + Py_INCREF(en); + return en; +} + +static PyMethodDef enum_methods[] = { + {"next", (PyCFunction)enum_next, METH_NOARGS, + "return the next (index, value) pair, or raise StopIteration"}, + {NULL, NULL} /* sentinel */ +}; + +static char enum_doc[] = + "enumerate(iterable) -> create an enumerating-iterator"; + +PyTypeObject PyEnum_Type = { + PyObject_HEAD_INIT(&PyType_Type) + 0, /* ob_size */ + "enumerate", /* tp_name */ + sizeof(enumobject), /* tp_basicsize */ + 0, /* tp_itemsize */ + /* methods */ + (destructor)enum_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + PyObject_GenericGetAttr, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_BASETYPE, /* tp_flags */ + enum_doc, /* tp_doc */ + (traverseproc)enum_traverse, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + (getiterfunc)enum_getiter, /* tp_iter */ + (iternextfunc)enum_next, /* tp_iternext */ + enum_methods, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + 0, /* tp_init */ + PyType_GenericAlloc, /* tp_alloc */ + enum_new, /* tp_new */ + PyObject_GC_Del, /* tp_free */ +}; diff --git a/Python/bltinmodule.c b/Python/bltinmodule.c index 680152d510b..06d86167aa7 100644 --- a/Python/bltinmodule.c +++ b/Python/bltinmodule.c @@ -1864,6 +1864,7 @@ _PyBuiltin_Init(void) SETBUILTIN("complex", &PyComplex_Type); #endif SETBUILTIN("dict", &PyDict_Type); + SETBUILTIN("enumerate", &PyEnum_Type); SETBUILTIN("float", &PyFloat_Type); SETBUILTIN("property", &PyProperty_Type); SETBUILTIN("int", &PyInt_Type);