Do the same thing to complex that I did to str: the rich comparison

function returns NotImplemented when comparing objects whose
tp_richcompare slot is not itself.
This commit is contained in:
Guido van Rossum 2001-09-24 17:52:04 +00:00
parent e47df7a211
commit 2205642fe0
2 changed files with 25 additions and 7 deletions

View File

@ -1863,6 +1863,21 @@ def classic_comparisons():
def rich_comparisons(): def rich_comparisons():
if verbose: if verbose:
print "Testing rich comparisons..." print "Testing rich comparisons..."
class Z(complex):
pass
z = Z(1)
verify(z == 1+0j)
verify(1+0j == z)
class ZZ(complex):
def __eq__(self, other):
try:
return abs(self - other) <= 1e-6
except:
return NotImplemented
zz = ZZ(1.0000003)
verify(zz == 1+0j)
verify(1+0j == zz)
class classic: class classic:
pass pass
for base in (classic, int, object, list): for base in (classic, int, object, list):

View File

@ -553,12 +553,6 @@ complex_richcompare(PyObject *v, PyObject *w, int op)
Py_complex i, j; Py_complex i, j;
PyObject *res; PyObject *res;
if (op != Py_EQ && op != Py_NE) {
PyErr_SetString(PyExc_TypeError,
"cannot compare complex numbers using <, <=, >, >=");
return NULL;
}
c = PyNumber_CoerceEx(&v, &w); c = PyNumber_CoerceEx(&v, &w);
if (c < 0) if (c < 0)
return NULL; return NULL;
@ -566,7 +560,10 @@ complex_richcompare(PyObject *v, PyObject *w, int op)
Py_INCREF(Py_NotImplemented); Py_INCREF(Py_NotImplemented);
return Py_NotImplemented; return Py_NotImplemented;
} }
if (!PyComplex_Check(v) || !PyComplex_Check(w)) { /* May sure both arguments use complex comparison.
This implies PyComplex_Check(a) && PyComplex_Check(b). */
if (v->ob_type->tp_richcompare != complex_richcompare ||
w->ob_type->tp_richcompare != complex_richcompare) {
Py_DECREF(v); Py_DECREF(v);
Py_DECREF(w); Py_DECREF(w);
Py_INCREF(Py_NotImplemented); Py_INCREF(Py_NotImplemented);
@ -578,6 +575,12 @@ complex_richcompare(PyObject *v, PyObject *w, int op)
Py_DECREF(v); Py_DECREF(v);
Py_DECREF(w); Py_DECREF(w);
if (op != Py_EQ && op != Py_NE) {
PyErr_SetString(PyExc_TypeError,
"cannot compare complex numbers using <, <=, >, >=");
return NULL;
}
if ((i.real == j.real && i.imag == j.imag) == (op == Py_EQ)) if ((i.real == j.real && i.imag == j.imag) == (op == Py_EQ))
res = Py_True; res = Py_True;
else else