From cb8d368b824a38a0b04598ba2bcd107d6aae3595 Mon Sep 17 00:00:00 2001 From: Tim Peters Date: Sat, 5 May 2001 21:05:01 +0000 Subject: [PATCH] Reimplement PySequence_Contains() and instance_contains(), so they work safely together and don't duplicate logic (the common logic was factored out into new private API function _PySequence_IterContains()). Visible change: some_complex_number in some_instance no longer blows up if some_instance has __getitem__ but neither __contains__ nor __iter__. test_iter changed to ensure that remains true. --- Include/abstract.h | 12 +++++++- Lib/test/test_iter.py | 24 ++++------------ Objects/abstract.c | 41 +++++++++++++------------- Objects/classobject.c | 67 +++++++++++++++++++------------------------ 4 files changed, 67 insertions(+), 77 deletions(-) diff --git a/Include/abstract.h b/Include/abstract.h index d5f4a9978d4..9082edb0b83 100644 --- a/Include/abstract.h +++ b/Include/abstract.h @@ -932,7 +932,17 @@ xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx*/ expression: o.count(value). */ - DL_IMPORT(int) PySequence_Contains(PyObject *o, PyObject *value); + DL_IMPORT(int) PySequence_Contains(PyObject *seq, PyObject *ob); + /* + Return -1 if error; 1 if ob in seq; 0 if ob not in seq. + Use __contains__ if possible, else _PySequence_IterContains(). + */ + + DL_IMPORT(int) _PySequence_IterContains(PyObject *seq, PyObject *ob); + /* + Return -1 if error; 1 if ob in seq; 0 if ob not in seq. + Always uses the iteration protocol, and only Py_EQ comparisons. + */ /* For DLL-level backwards compatibility */ #undef PySequence_In diff --git a/Lib/test/test_iter.py b/Lib/test/test_iter.py index 7d15e1cfb8f..22a7c4460d4 100644 --- a/Lib/test/test_iter.py +++ b/Lib/test/test_iter.py @@ -474,24 +474,12 @@ class TestCase(unittest.TestCase): # Test iterators with 'x in y' and 'x not in y'. def test_in_and_not_in(self): - sc5 = IteratingSequenceClass(5) - for i in range(5): - self.assert_(i in sc5) - # CAUTION: This test fails on 3-12j if sc5 is SequenceClass(5) - # instead, with: - # TypeError: cannot compare complex numbers using <, <=, >, >= - # The trail leads back to instance_contains() in classobject.c, - # under comment: - # /* fall back to previous behavior */ - # IteratingSequenceClass(5) avoids the same problem only because - # it lacks __getitem__: instance_contains *tries* to do a wrong - # thing with it too, but aborts with an AttributeError the first - # time it calls instance_item(); PySequence_Contains() then catches - # that and clears it, and tries the iterator-based "contains" - # instead. But this is hanging together by a thread. - for i in "abc", -1, 5, 42.42, (3, 4), [], {1: 1}, 3-12j, sc5: - self.assert_(i not in sc5) - del sc5 + for sc5 in IteratingSequenceClass(5), SequenceClass(5): + for i in range(5): + self.assert_(i in sc5) + for i in "abc", -1, 5, 42.42, (3, 4), [], {1: 1}, 3-12j, sc5: + self.assert_(i not in sc5) + del sc5 self.assertRaises(TypeError, lambda: 3 in 12) self.assertRaises(TypeError, lambda: 3 not in map) diff --git a/Objects/abstract.c b/Objects/abstract.c index 21c1ef1de46..c1d77897478 100644 --- a/Objects/abstract.c +++ b/Objects/abstract.c @@ -1381,29 +1381,14 @@ Fail: return -1; } -/* Return -1 if error; 1 if v in w; 0 if v not in w. */ +/* Return -1 if error; 1 if ob in seq; 0 if ob not in seq. + * Always uses the iteration protocol, and only Py_EQ comparison. + */ int -PySequence_Contains(PyObject *w, PyObject *v) /* v in w */ +_PySequence_IterContains(PyObject *seq, PyObject *ob) { - PyObject *it; /* iter(w) */ int result; - - if (PyType_HasFeature(w->ob_type, Py_TPFLAGS_HAVE_SEQUENCE_IN)) { - PySequenceMethods *sq = w->ob_type->tp_as_sequence; - if (sq != NULL && sq->sq_contains != NULL) { - result = (*sq->sq_contains)(w, v); - if (result >= 0) - return result; - assert(PyErr_Occurred()); - if (PyErr_ExceptionMatches(PyExc_AttributeError)) - PyErr_Clear(); - else - return result; - } - } - - /* Try exhaustive iteration. */ - it = PyObject_GetIter(w); + PyObject *it = PyObject_GetIter(seq); if (it == NULL) { PyErr_SetString(PyExc_TypeError, "'in' or 'not in' needs iterable right argument"); @@ -1417,7 +1402,7 @@ PySequence_Contains(PyObject *w, PyObject *v) /* v in w */ result = PyErr_Occurred() ? -1 : 0; break; } - cmp = PyObject_RichCompareBool(v, item, Py_EQ); + cmp = PyObject_RichCompareBool(ob, item, Py_EQ); Py_DECREF(item); if (cmp == 0) continue; @@ -1428,6 +1413,20 @@ PySequence_Contains(PyObject *w, PyObject *v) /* v in w */ return result; } +/* Return -1 if error; 1 if ob in seq; 0 if ob not in seq. + * Use sq_contains if possible, else defer to _PySequence_IterContains(). + */ +int +PySequence_Contains(PyObject *seq, PyObject *ob) +{ + if (PyType_HasFeature(seq->ob_type, Py_TPFLAGS_HAVE_SEQUENCE_IN)) { + PySequenceMethods *sqm = seq->ob_type->tp_as_sequence; + if (sqm != NULL && sqm->sq_contains != NULL) + return (*sqm->sq_contains)(seq, ob); + } + return _PySequence_IterContains(seq, ob); +} + /* Backwards compatibility */ #undef PySequence_In int diff --git a/Objects/classobject.c b/Objects/classobject.c index 2babbfbd63f..67732ca2289 100644 --- a/Objects/classobject.c +++ b/Objects/classobject.c @@ -1131,11 +1131,15 @@ instance_ass_slice(PyInstanceObject *inst, int i, int j, PyObject *value) return 0; } -static int instance_contains(PyInstanceObject *inst, PyObject *member) +static int +instance_contains(PyInstanceObject *inst, PyObject *member) { static PyObject *__contains__; - PyObject *func, *arg, *res; - int ret; + PyObject *func; + + /* Try __contains__ first. + * If that can't be done, try iterator-based searching. + */ if(__contains__ == NULL) { __contains__ = PyString_InternFromString("__contains__"); @@ -1143,45 +1147,34 @@ static int instance_contains(PyInstanceObject *inst, PyObject *member) return -1; } func = instance_getattr(inst, __contains__); - if(func == NULL) { - /* fall back to previous behavior */ - int i, cmp_res; - - if(!PyErr_ExceptionMatches(PyExc_AttributeError)) + if (func) { + PyObject *res; + int ret; + PyObject *arg = Py_BuildValue("(O)", member); + if(arg == NULL) { + Py_DECREF(func); return -1; - PyErr_Clear(); - for(i=0;;i++) { - PyObject *obj = instance_item(inst, i); - int ret = 0; - - if(obj == NULL) { - if(!PyErr_ExceptionMatches(PyExc_IndexError)) - return -1; - PyErr_Clear(); - return 0; - } - if(PyObject_Cmp(obj, member, &cmp_res) == -1) - ret = -1; - if(cmp_res == 0) - ret = 1; - Py_DECREF(obj); - if(ret) - return ret; } - } - arg = Py_BuildValue("(O)", member); - if(arg == NULL) { + res = PyEval_CallObject(func, arg); Py_DECREF(func); - return -1; + Py_DECREF(arg); + if(res == NULL) + return -1; + ret = PyObject_IsTrue(res); + Py_DECREF(res); + return ret; } - res = PyEval_CallObject(func, arg); - Py_DECREF(func); - Py_DECREF(arg); - if(res == NULL) + + /* Couldn't find __contains__. */ + if (PyErr_ExceptionMatches(PyExc_AttributeError)) { + /* Assume the failure was simply due to that there is no + * __contains__ attribute, and try iterating instead. + */ + PyErr_Clear(); + return _PySequence_IterContains((PyObject *)inst, member); + } + else return -1; - ret = PyObject_IsTrue(res); - Py_DECREF(res); - return ret; } static PySequenceMethods