diff --git a/Lib/test/test_descr.py b/Lib/test/test_descr.py index bd046052a59..4ed85dfce9b 100644 --- a/Lib/test/test_descr.py +++ b/Lib/test/test_descr.py @@ -1833,7 +1833,9 @@ def str_subclass_as_dict_key(): def classic_comparisons(): if verbose: print "Testing classic comparisons..." - for base in (int, object): + class classic: + pass + for base in (classic, int, object): if verbose: print " (base = %s)" % base class C(base): def __init__(self, value): @@ -1858,6 +1860,69 @@ def classic_comparisons(): verify(cmp(c[x], y) == cmp(x, y), "x=%d, y=%d" % (x, y)) verify(cmp(x, c[y]) == cmp(x, y), "x=%d, y=%d" % (x, y)) +def rich_comparisons(): + if verbose: + print "Testing rich comparisons..." + class classic: + pass + for base in (classic, int, object, list): + if verbose: print " (base = %s)" % base + class C(base): + def __init__(self, value): + self.value = int(value) + def __cmp__(self, other): + raise TestFailed, "shouldn't call __cmp__" + def __eq__(self, other): + if isinstance(other, C): + return self.value == other.value + if isinstance(other, int) or isinstance(other, long): + return self.value == other + return NotImplemented + def __ne__(self, other): + if isinstance(other, C): + return self.value != other.value + if isinstance(other, int) or isinstance(other, long): + return self.value != other + return NotImplemented + def __lt__(self, other): + if isinstance(other, C): + return self.value < other.value + if isinstance(other, int) or isinstance(other, long): + return self.value < other + return NotImplemented + def __le__(self, other): + if isinstance(other, C): + return self.value <= other.value + if isinstance(other, int) or isinstance(other, long): + return self.value <= other + return NotImplemented + def __gt__(self, other): + if isinstance(other, C): + return self.value > other.value + if isinstance(other, int) or isinstance(other, long): + return self.value > other + return NotImplemented + def __ge__(self, other): + if isinstance(other, C): + return self.value >= other.value + if isinstance(other, int) or isinstance(other, long): + return self.value >= other + return NotImplemented + c1 = C(1) + c2 = C(2) + c3 = C(3) + verify(c1 == 1) + c = {1: c1, 2: c2, 3: c3} + for x in 1, 2, 3: + for y in 1, 2, 3: + for op in "<", "<=", "==", "!=", ">", ">=": + verify(eval("c[x] %s c[y]" % op) == eval("x %s y" % op), + "x=%d, y=%d" % (x, y)) + verify(eval("c[x] %s y" % op) == eval("x %s y" % op), + "x=%d, y=%d" % (x, y)) + verify(eval("x %s c[y]" % op) == eval("x %s y" % op), + "x=%d, y=%d" % (x, y)) + def all(): lists() @@ -1897,6 +1962,7 @@ def all(): restricted() str_subclass_as_dict_key() classic_comparisons() + rich_comparisons() all()