From ab3b0343b89b4683148dadaf89728ee1198ebee5 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 18 Sep 2001 20:38:53 +0000 Subject: [PATCH] Hopefully fix 3-way comparisons. This unfortunately adds yet another hack, and it's even more disgusting than a PyInstance_Check() call. If the tp_compare slot is the slot used for overrides in Python, it's always called. Add some tests that show what should work too. --- Include/object.h | 4 ++++ Lib/test/test_descr.py | 28 ++++++++++++++++++++++++++++ Objects/object.c | 16 +++++++++++++++- Objects/typeobject.c | 11 ++++++----- 4 files changed, 53 insertions(+), 6 deletions(-) diff --git a/Include/object.h b/Include/object.h index d9c35144b32..160331ed23c 100644 --- a/Include/object.h +++ b/Include/object.h @@ -346,6 +346,10 @@ extern DL_IMPORT(int) PyNumber_CoerceEx(PyObject **, PyObject **); extern DL_IMPORT(void) (*PyObject_ClearWeakRefs)(PyObject *); +/* A slot function whose address we need to compare */ +extern int _PyObject_SlotCompare(PyObject *, PyObject *); + + /* PyObject_Dir(obj) acts like Python __builtin__.dir(obj), returning a list of strings. PyObject_Dir(NULL) is like __builtin__.dir(), returning the names of the current locals. In this case, if there are diff --git a/Lib/test/test_descr.py b/Lib/test/test_descr.py index fc003186737..bd046052a59 100644 --- a/Lib/test/test_descr.py +++ b/Lib/test/test_descr.py @@ -1831,6 +1831,33 @@ def str_subclass_as_dict_key(): verify(cistr('ONe') in d) verify(d.get(cistr('thrEE')) == 3) +def classic_comparisons(): + if verbose: print "Testing classic comparisons..." + for base in (int, object): + if verbose: print " (base = %s)" % base + class C(base): + def __init__(self, value): + self.value = int(value) + def __cmp__(self, other): + if isinstance(other, C): + return cmp(self.value, other.value) + if isinstance(other, int) or isinstance(other, long): + return cmp(self.value, other) + return NotImplemented + c1 = C(1) + c2 = C(2) + c3 = C(3) + verify(c1 == 1) + c = {1: c1, 2: c2, 3: c3} + for x in 1, 2, 3: + for y in 1, 2, 3: + verify(cmp(c[x], c[y]) == cmp(x, y), "x=%d, y=%d" % (x, y)) + for op in "<", "<=", "==", "!=", ">", ">=": + verify(eval("c[x] %s c[y]" % op) == eval("x %s y" % op), + "x=%d, y=%d" % (x, y)) + verify(cmp(c[x], y) == cmp(x, y), "x=%d, y=%d" % (x, y)) + verify(cmp(x, c[y]) == cmp(x, y), "x=%d, y=%d" % (x, y)) + def all(): lists() @@ -1869,6 +1896,7 @@ def all(): keywords() restricted() str_subclass_as_dict_key() + classic_comparisons() all() diff --git a/Objects/object.c b/Objects/object.c index c56c3be9175..668bd4f3324 100644 --- a/Objects/object.c +++ b/Objects/object.c @@ -455,11 +455,25 @@ try_3way_compare(PyObject *v, PyObject *w) /* Comparisons involving instances are given to instance_compare, which has the same return conventions as this function. */ + f = v->ob_type->tp_compare; if (PyInstance_Check(v)) - return (*v->ob_type->tp_compare)(v, w); + return (*f)(v, w); if (PyInstance_Check(w)) return (*w->ob_type->tp_compare)(v, w); + /* If both have the same (non-NULL) tp_compare, use it. */ + if (f != NULL && f == w->ob_type->tp_compare) { + c = (*f)(v, w); + if (c < 0 && PyErr_Occurred()) + return -1; + return c < 0 ? -1 : c > 0 ? 1 : 0; + } + + /* If either tp_compare is _PyObject_SlotCompare, that's safe. */ + if (f == _PyObject_SlotCompare || + w->ob_type->tp_compare == _PyObject_SlotCompare) + return _PyObject_SlotCompare(v, w); + /* Try coercion; if it fails, give up */ c = PyNumber_CoerceEx(&v, &w); if (c < 0) diff --git a/Objects/typeobject.c b/Objects/typeobject.c index 792a9f3c286..26ddabe0c26 100644 --- a/Objects/typeobject.c +++ b/Objects/typeobject.c @@ -2761,17 +2761,18 @@ half_compare(PyObject *self, PyObject *other) return 2; } -static int -slot_tp_compare(PyObject *self, PyObject *other) +/* This slot is published for the benefit of try_3way_compare in object.c */ +int +_PyObject_SlotCompare(PyObject *self, PyObject *other) { int c; - if (self->ob_type->tp_compare == slot_tp_compare) { + if (self->ob_type->tp_compare == _PyObject_SlotCompare) { c = half_compare(self, other); if (c <= 1) return c; } - if (other->ob_type->tp_compare == slot_tp_compare) { + if (other->ob_type->tp_compare == _PyObject_SlotCompare) { c = half_compare(other, self); if (c < -1) return -2; @@ -3190,7 +3191,7 @@ override_slots(PyTypeObject *type, PyObject *dict) PyDict_GetItemString(dict, "__repr__")) type->tp_print = NULL; - TPSLOT("__cmp__", tp_compare, slot_tp_compare); + TPSLOT("__cmp__", tp_compare, _PyObject_SlotCompare); TPSLOT("__repr__", tp_repr, slot_tp_repr); TPSLOT("__hash__", tp_hash, slot_tp_hash); TPSLOT("__call__", tp_call, slot_tp_call);