Minor bugs in the __index__ code (PEP 357), with tests.

This commit is contained in:
Armin Rigo 2006-03-30 14:04:02 +00:00
parent 4ef3a23a35
commit 314861c568
3 changed files with 100 additions and 132 deletions

View File

@ -10,82 +10,13 @@ class newstyle(object):
def __index__(self): def __index__(self):
return self.ind return self.ind
class ListTestCase(unittest.TestCase): class BaseTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.seq = [0,10,20,30,40,50]
self.o = oldstyle() self.o = oldstyle()
self.n = newstyle() self.n = newstyle()
self.o2 = oldstyle() self.o2 = oldstyle()
self.n2 = newstyle() self.n2 = newstyle()
def test_basic(self):
self.o.ind = -2
self.n.ind = 2
assert(self.seq[self.n] == 20)
assert(self.seq[self.o] == 40)
assert(operator.index(self.o) == -2)
assert(operator.index(self.n) == 2)
def test_error(self):
self.o.ind = 'dumb'
self.n.ind = 'bad'
myfunc = lambda x, obj: obj.seq[x]
self.failUnlessRaises(TypeError, operator.index, self.o)
self.failUnlessRaises(TypeError, operator.index, self.n)
self.failUnlessRaises(TypeError, myfunc, self.o, self)
self.failUnlessRaises(TypeError, myfunc, self.n, self)
def test_slice(self):
self.o.ind = 1
self.o2.ind = 3
self.n.ind = 2
self.n2.ind = 4
assert(self.seq[self.o:self.o2] == self.seq[1:3])
assert(self.seq[self.n:self.n2] == self.seq[2:4])
class TupleTestCase(unittest.TestCase):
def setUp(self):
self.seq = (0,10,20,30,40,50)
self.o = oldstyle()
self.n = newstyle()
self.o2 = oldstyle()
self.n2 = newstyle()
def test_basic(self):
self.o.ind = -2
self.n.ind = 2
assert(self.seq[self.n] == 20)
assert(self.seq[self.o] == 40)
assert(operator.index(self.o) == -2)
assert(operator.index(self.n) == 2)
def test_error(self):
self.o.ind = 'dumb'
self.n.ind = 'bad'
myfunc = lambda x, obj: obj.seq[x]
self.failUnlessRaises(TypeError, operator.index, self.o)
self.failUnlessRaises(TypeError, operator.index, self.n)
self.failUnlessRaises(TypeError, myfunc, self.o, self)
self.failUnlessRaises(TypeError, myfunc, self.n, self)
def test_slice(self):
self.o.ind = 1
self.o2.ind = 3
self.n.ind = 2
self.n2.ind = 4
assert(self.seq[self.o:self.o2] == self.seq[1:3])
assert(self.seq[self.n:self.n2] == self.seq[2:4])
class StringTestCase(unittest.TestCase):
def setUp(self):
self.seq = "this is a test"
self.o = oldstyle()
self.n = newstyle()
self.o2 = oldstyle()
self.n2 = newstyle()
def test_basic(self): def test_basic(self):
self.o.ind = -2 self.o.ind = -2
self.n.ind = 2 self.n.ind = 2
@ -111,40 +42,86 @@ class StringTestCase(unittest.TestCase):
assert(self.seq[self.o:self.o2] == self.seq[1:3]) assert(self.seq[self.o:self.o2] == self.seq[1:3])
assert(self.seq[self.n:self.n2] == self.seq[2:4]) assert(self.seq[self.n:self.n2] == self.seq[2:4])
def test_repeat(self):
self.o.ind = 3
self.n.ind = 2
assert(self.seq * self.o == self.seq * 3)
assert(self.seq * self.n == self.seq * 2)
assert(self.o * self.seq == self.seq * 3)
assert(self.n * self.seq == self.seq * 2)
class UnicodeTestCase(unittest.TestCase): def test_wrappers(self):
def setUp(self): n = self.n
self.seq = u"this is a test" n.ind = 5
self.o = oldstyle() assert n.__index__() == 5
self.n = newstyle() assert 6 .__index__() == 6
self.o2 = oldstyle() assert -7L.__index__() == -7
self.n2 = newstyle() assert self.seq.__getitem__(n) == self.seq[5]
assert self.seq.__mul__(n) == self.seq * 5
assert self.seq.__rmul__(n) == self.seq * 5
def test_infinite_recusion(self):
class Trap1(int):
def __index__(self):
return self
class Trap2(long):
def __index__(self):
return self
self.failUnlessRaises(TypeError, operator.getitem, self.seq, Trap1())
self.failUnlessRaises(TypeError, operator.getitem, self.seq, Trap2())
def test_basic(self): class ListTestCase(BaseTestCase):
seq = [0,10,20,30,40,50]
def test_setdelitem(self):
self.o.ind = -2 self.o.ind = -2
self.n.ind = 2 self.n.ind = 2
assert(self.seq[self.n] == self.seq[2]) lst = list('ab!cdefghi!j')
assert(self.seq[self.o] == self.seq[-2]) del lst[self.o]
assert(operator.index(self.o) == -2) del lst[self.n]
assert(operator.index(self.n) == 2) lst[self.o] = 'X'
lst[self.n] = 'Y'
assert lst == list('abYdefghXj')
def test_error(self): lst = [5, 6, 7, 8, 9, 10, 11]
self.o.ind = 'dumb' lst.__setitem__(self.n, "here")
self.n.ind = 'bad' assert lst == [5, 6, "here", 8, 9, 10, 11]
myfunc = lambda x, obj: obj.seq[x] lst.__delitem__(self.n)
self.failUnlessRaises(TypeError, operator.index, self.o) assert lst == [5, 6, 8, 9, 10, 11]
self.failUnlessRaises(TypeError, operator.index, self.n)
self.failUnlessRaises(TypeError, myfunc, self.o, self)
self.failUnlessRaises(TypeError, myfunc, self.n, self)
def test_slice(self): def test_inplace_repeat(self):
self.o.ind = 1 self.o.ind = 2
self.o2.ind = 3 self.n.ind = 3
self.n.ind = 2 lst = [6, 4]
self.n2.ind = 4 lst *= self.o
assert(self.seq[self.o:self.o2] == self.seq[1:3]) assert lst == [6, 4, 6, 4]
assert(self.seq[self.n:self.n2] == self.seq[2:4]) lst *= self.n
assert lst == [6, 4, 6, 4] * 3
lst = [5, 6, 7, 8, 9, 11]
l2 = lst.__imul__(self.n)
assert l2 is lst
assert lst == [5, 6, 7, 8, 9, 11] * 3
class TupleTestCase(BaseTestCase):
seq = (0,10,20,30,40,50)
class StringTestCase(BaseTestCase):
seq = "this is a test"
class UnicodeTestCase(BaseTestCase):
seq = u"this is a test"
class XRangeTestCase(unittest.TestCase):
def test_xrange(self):
n = newstyle()
n.ind = 5
assert xrange(1, 20)[n] == 6
assert xrange(1, 20).__getitem__(n) == 6
def test_main(): def test_main():
@ -152,7 +129,8 @@ def test_main():
ListTestCase, ListTestCase,
TupleTestCase, TupleTestCase,
StringTestCase, StringTestCase,
UnicodeTestCase UnicodeTestCase,
XRangeTestCase,
) )
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -942,8 +942,9 @@ PyNumber_Index(PyObject *item)
value = nb->nb_index(item); value = nb->nb_index(item);
} }
else { else {
PyErr_SetString(PyExc_IndexError, PyErr_Format(PyExc_TypeError,
"object cannot be interpreted as an index"); "'%.200s' object cannot be interpreted "
"as an index", item->ob_type->tp_name);
} }
return value; return value;
} }

View File

@ -3542,12 +3542,16 @@ wrap_unaryfunc(PyObject *self, PyObject *args, void *wrapped)
} }
static PyObject * static PyObject *
wrap_ssizeargfunc(PyObject *self, PyObject *args, void *wrapped) wrap_indexargfunc(PyObject *self, PyObject *args, void *wrapped)
{ {
ssizeargfunc func = (ssizeargfunc)wrapped; ssizeargfunc func = (ssizeargfunc)wrapped;
PyObject* o;
Py_ssize_t i; Py_ssize_t i;
if (!PyArg_ParseTuple(args, "n", &i)) if (!PyArg_UnpackTuple(args, "", 1, 1, &o))
return NULL;
i = PyNumber_Index(o);
if (i == -1 && PyErr_Occurred())
return NULL; return NULL;
return (*func)(self, i); return (*func)(self, i);
} }
@ -3557,7 +3561,7 @@ getindex(PyObject *self, PyObject *arg)
{ {
Py_ssize_t i; Py_ssize_t i;
i = PyInt_AsSsize_t(arg); i = PyNumber_Index(arg);
if (i == -1 && PyErr_Occurred()) if (i == -1 && PyErr_Occurred())
return -1; return -1;
if (i < 0) { if (i < 0) {
@ -4366,36 +4370,21 @@ slot_nb_nonzero(PyObject *self)
static Py_ssize_t static Py_ssize_t
slot_nb_index(PyObject *self) slot_nb_index(PyObject *self)
{ {
PyObject *func, *args;
static PyObject *index_str; static PyObject *index_str;
Py_ssize_t result = -1; PyObject *temp = call_method(self, "__index__", &index_str, "()");
Py_ssize_t result;
func = lookup_maybe(self, "__index__", &index_str); if (temp == NULL)
if (func == NULL) {
if (!PyErr_Occurred()) {
PyErr_SetString(PyExc_TypeError,
"object cannot be interpreted as an index");
}
return -1; return -1;
} if (PyInt_CheckExact(temp) || PyLong_CheckExact(temp)) {
args = PyTuple_New(0); result = temp->ob_type->tp_as_number->nb_index(temp);
if (args != NULL) {
PyObject *temp = PyObject_Call(func, args, NULL);
Py_DECREF(args);
if (temp != NULL) {
if (PyInt_Check(temp) || PyLong_Check(temp)) {
result =
temp->ob_type->tp_as_number->nb_index(temp);
}
else {
PyErr_SetString(PyExc_TypeError,
"__index__ must return an int or a long");
result = -1;
}
Py_DECREF(temp);
}
} }
Py_DECREF(func); else {
PyErr_SetString(PyExc_TypeError,
"__index__ must return an int or a long");
result = -1;
}
Py_DECREF(temp);
return result; return result;
} }
@ -5026,9 +5015,9 @@ static slotdef slotdefs[] = {
test_descr.notimplemented() */ test_descr.notimplemented() */
SQSLOT("__add__", sq_concat, NULL, wrap_binaryfunc, SQSLOT("__add__", sq_concat, NULL, wrap_binaryfunc,
"x.__add__(y) <==> x+y"), "x.__add__(y) <==> x+y"),
SQSLOT("__mul__", sq_repeat, NULL, wrap_ssizeargfunc, SQSLOT("__mul__", sq_repeat, NULL, wrap_indexargfunc,
"x.__mul__(n) <==> x*n"), "x.__mul__(n) <==> x*n"),
SQSLOT("__rmul__", sq_repeat, NULL, wrap_ssizeargfunc, SQSLOT("__rmul__", sq_repeat, NULL, wrap_indexargfunc,
"x.__rmul__(n) <==> n*x"), "x.__rmul__(n) <==> n*x"),
SQSLOT("__getitem__", sq_item, slot_sq_item, wrap_sq_item, SQSLOT("__getitem__", sq_item, slot_sq_item, wrap_sq_item,
"x.__getitem__(y) <==> x[y]"), "x.__getitem__(y) <==> x[y]"),
@ -5054,7 +5043,7 @@ static slotdef slotdefs[] = {
SQSLOT("__iadd__", sq_inplace_concat, NULL, SQSLOT("__iadd__", sq_inplace_concat, NULL,
wrap_binaryfunc, "x.__iadd__(y) <==> x+=y"), wrap_binaryfunc, "x.__iadd__(y) <==> x+=y"),
SQSLOT("__imul__", sq_inplace_repeat, NULL, SQSLOT("__imul__", sq_inplace_repeat, NULL,
wrap_ssizeargfunc, "x.__imul__(y) <==> x*=y"), wrap_indexargfunc, "x.__imul__(y) <==> x*=y"),
MPSLOT("__len__", mp_length, slot_mp_length, wrap_lenfunc, MPSLOT("__len__", mp_length, slot_mp_length, wrap_lenfunc,
"x.__len__() <==> len(x)"), "x.__len__() <==> len(x)"),