"""Test equality and order comparisons.""" import unittest from test.support import ALWAYS_EQ from fractions import Fraction from decimal import Decimal class ComparisonSimpleTest(unittest.TestCase): """Test equality and order comparisons for some simple cases.""" class Empty: def __repr__(self): return '' class Cmp: def __init__(self, arg): self.arg = arg def __repr__(self): return '' % self.arg def __eq__(self, other): return self.arg == other set1 = [2, 2.0, 2, 2+0j, Cmp(2.0)] set2 = [[1], (3,), None, Empty()] candidates = set1 + set2 def test_comparisons(self): for a in self.candidates: for b in self.candidates: if ((a in self.set1) and (b in self.set1)) or a is b: self.assertEqual(a, b) else: self.assertNotEqual(a, b) def test_id_comparisons(self): # Ensure default comparison compares id() of args L = [] for i in range(10): L.insert(len(L)//2, self.Empty()) for a in L: for b in L: self.assertEqual(a == b, a is b, 'a=%r, b=%r' % (a, b)) def test_ne_defaults_to_not_eq(self): a = self.Cmp(1) b = self.Cmp(1) c = self.Cmp(2) self.assertIs(a == b, True) self.assertIs(a != b, False) self.assertIs(a != c, True) def test_ne_high_priority(self): """object.__ne__() should allow reflected __ne__() to be tried""" calls = [] class Left: # Inherits object.__ne__() def __eq__(*args): calls.append('Left.__eq__') return NotImplemented class Right: def __eq__(*args): calls.append('Right.__eq__') return NotImplemented def __ne__(*args): calls.append('Right.__ne__') return NotImplemented Left() != Right() self.assertSequenceEqual(calls, ['Left.__eq__', 'Right.__ne__']) def test_ne_low_priority(self): """object.__ne__() should not invoke reflected __eq__()""" calls = [] class Base: # Inherits object.__ne__() def __eq__(*args): calls.append('Base.__eq__') return NotImplemented class Derived(Base): # Subclassing forces higher priority def __eq__(*args): calls.append('Derived.__eq__') return NotImplemented def __ne__(*args): calls.append('Derived.__ne__') return NotImplemented Base() != Derived() self.assertSequenceEqual(calls, ['Derived.__ne__', 'Base.__eq__']) def test_other_delegation(self): """No default delegation between operations except __ne__()""" ops = ( ('__eq__', lambda a, b: a == b), ('__lt__', lambda a, b: a < b), ('__le__', lambda a, b: a <= b), ('__gt__', lambda a, b: a > b), ('__ge__', lambda a, b: a >= b), ) for name, func in ops: with self.subTest(name): def unexpected(*args): self.fail('Unexpected operator method called') class C: __ne__ = unexpected for other, _ in ops: if other != name: setattr(C, other, unexpected) if name == '__eq__': self.assertIs(func(C(), object()), False) else: self.assertRaises(TypeError, func, C(), object()) def test_issue_1393(self): x = lambda: None self.assertEqual(x, ALWAYS_EQ) self.assertEqual(ALWAYS_EQ, x) y = object() self.assertEqual(y, ALWAYS_EQ) self.assertEqual(ALWAYS_EQ, y) class ComparisonFullTest(unittest.TestCase): """Test equality and ordering comparisons for built-in types and user-defined classes that implement relevant combinations of rich comparison methods. """ class CompBase: """Base class for classes with rich comparison methods. The "x" attribute should be set to an underlying value to compare. Derived classes have a "meth" tuple attribute listing names of comparison methods implemented. See assert_total_order(). """ # Class without any rich comparison methods. class CompNone(CompBase): meth = () # Classes with all combinations of value-based equality comparison methods. class CompEq(CompBase): meth = ("eq",) def __eq__(self, other): return self.x == other.x class CompNe(CompBase): meth = ("ne",) def __ne__(self, other): return self.x != other.x class CompEqNe(CompBase): meth = ("eq", "ne") def __eq__(self, other): return self.x == other.x def __ne__(self, other): return self.x != other.x # Classes with all combinations of value-based less/greater-than order # comparison methods. class CompLt(CompBase): meth = ("lt",) def __lt__(self, other): return self.x < other.x class CompGt(CompBase): meth = ("gt",) def __gt__(self, other): return self.x > other.x class CompLtGt(CompBase): meth = ("lt", "gt") def __lt__(self, other): return self.x < other.x def __gt__(self, other): return self.x > other.x # Classes with all combinations of value-based less/greater-or-equal-than # order comparison methods class CompLe(CompBase): meth = ("le",) def __le__(self, other): return self.x <= other.x class CompGe(CompBase): meth = ("ge",) def __ge__(self, other): return self.x >= other.x class CompLeGe(CompBase): meth = ("le", "ge") def __le__(self, other): return self.x <= other.x def __ge__(self, other): return self.x >= other.x # It should be sufficient to combine the comparison methods only within # each group. all_comp_classes = ( CompNone, CompEq, CompNe, CompEqNe, # equal group CompLt, CompGt, CompLtGt, # less/greater-than group CompLe, CompGe, CompLeGe) # less/greater-or-equal group def create_sorted_instances(self, class_, values): """Create objects of type `class_` and return them in a list. `values` is a list of values that determines the value of data attribute `x` of each object. Objects in the returned list are sorted by their identity. They assigned values in `values` list order. By assign decreasing values to objects with increasing identities, testcases can assert that order comparison is performed by value and not by identity. """ instances = [class_() for __ in range(len(values))] instances.sort(key=id) # Assign the provided values to the instances. for inst, value in zip(instances, values): inst.x = value return instances def assert_equality_only(self, a, b, equal): """Assert equality result and that ordering is not implemented. a, b: Instances to be tested (of same or different type). equal: Boolean indicating the expected equality comparison results. """ self.assertEqual(a == b, equal) self.assertEqual(b == a, equal) self.assertEqual(a != b, not equal) self.assertEqual(b != a, not equal) with self.assertRaisesRegex(TypeError, "not supported"): a < b with self.assertRaisesRegex(TypeError, "not supported"): a <= b with self.assertRaisesRegex(TypeError, "not supported"): a > b with self.assertRaisesRegex(TypeError, "not supported"): a >= b with self.assertRaisesRegex(TypeError, "not supported"): b < a with self.assertRaisesRegex(TypeError, "not supported"): b <= a with self.assertRaisesRegex(TypeError, "not supported"): b > a with self.assertRaisesRegex(TypeError, "not supported"): b >= a def assert_total_order(self, a, b, comp, a_meth=None, b_meth=None): """Test total ordering comparison of two instances. a, b: Instances to be tested (of same or different type). comp: -1, 0, or 1 indicates that the expected order comparison result for operations that are supported by the classes is a <, ==, or > b. a_meth, b_meth: Either None, indicating that all rich comparison methods are available, aa for builtins, or the tuple (subset) of "eq", "ne", "lt", "le", "gt", and "ge" that are available for the corresponding instance (of a user-defined class). """ self.assert_eq_subtest(a, b, comp, a_meth, b_meth) self.assert_ne_subtest(a, b, comp, a_meth, b_meth) self.assert_lt_subtest(a, b, comp, a_meth, b_meth) self.assert_le_subtest(a, b, comp, a_meth, b_meth) self.assert_gt_subtest(a, b, comp, a_meth, b_meth) self.assert_ge_subtest(a, b, comp, a_meth, b_meth) # The body of each subtest has form: # # if value-based comparison methods: # expect what the testcase defined for a op b and b rop a; # else: no value-based comparison # expect default behavior of object for a op b and b rop a. def assert_eq_subtest(self, a, b, comp, a_meth, b_meth): if a_meth is None or "eq" in a_meth or "eq" in b_meth: self.assertEqual(a == b, comp == 0) self.assertEqual(b == a, comp == 0) else: self.assertEqual(a == b, a is b) self.assertEqual(b == a, a is b) def assert_ne_subtest(self, a, b, comp, a_meth, b_meth): if a_meth is None or not {"ne", "eq"}.isdisjoint(a_meth + b_meth): self.assertEqual(a != b, comp != 0) self.assertEqual(b != a, comp != 0) else: self.assertEqual(a != b, a is not b) self.assertEqual(b != a, a is not b) def assert_lt_subtest(self, a, b, comp, a_meth, b_meth): if a_meth is None or "lt" in a_meth or "gt" in b_meth: self.assertEqual(a < b, comp < 0) self.assertEqual(b > a, comp < 0) else: with self.assertRaisesRegex(TypeError, "not supported"): a < b with self.assertRaisesRegex(TypeError, "not supported"): b > a def assert_le_subtest(self, a, b, comp, a_meth, b_meth): if a_meth is None or "le" in a_meth or "ge" in b_meth: self.assertEqual(a <= b, comp <= 0) self.assertEqual(b >= a, comp <= 0) else: with self.assertRaisesRegex(TypeError, "not supported"): a <= b with self.assertRaisesRegex(TypeError, "not supported"): b >= a def assert_gt_subtest(self, a, b, comp, a_meth, b_meth): if a_meth is None or "gt" in a_meth or "lt" in b_meth: self.assertEqual(a > b, comp > 0) self.assertEqual(b < a, comp > 0) else: with self.assertRaisesRegex(TypeError, "not supported"): a > b with self.assertRaisesRegex(TypeError, "not supported"): b < a def assert_ge_subtest(self, a, b, comp, a_meth, b_meth): if a_meth is None or "ge" in a_meth or "le" in b_meth: self.assertEqual(a >= b, comp >= 0) self.assertEqual(b <= a, comp >= 0) else: with self.assertRaisesRegex(TypeError, "not supported"): a >= b with self.assertRaisesRegex(TypeError, "not supported"): b <= a def test_objects(self): """Compare instances of type 'object'.""" a = object() b = object() self.assert_equality_only(a, a, True) self.assert_equality_only(a, b, False) def test_comp_classes_same(self): """Compare same-class instances with comparison methods.""" for cls in self.all_comp_classes: with self.subTest(cls): instances = self.create_sorted_instances(cls, (1, 2, 1)) # Same object. self.assert_total_order(instances[0], instances[0], 0, cls.meth, cls.meth) # Different objects, same value. self.assert_total_order(instances[0], instances[2], 0, cls.meth, cls.meth) # Different objects, value ascending for ascending identities. self.assert_total_order(instances[0], instances[1], -1, cls.meth, cls.meth) # different objects, value descending for ascending identities. # This is the interesting case to assert that order comparison # is performed based on the value and not based on the identity. self.assert_total_order(instances[1], instances[2], +1, cls.meth, cls.meth) def test_comp_classes_different(self): """Compare different-class instances with comparison methods.""" for cls_a in self.all_comp_classes: for cls_b in self.all_comp_classes: with self.subTest(a=cls_a, b=cls_b): a1 = cls_a() a1.x = 1 b1 = cls_b() b1.x = 1 b2 = cls_b() b2.x = 2 self.assert_total_order( a1, b1, 0, cls_a.meth, cls_b.meth) self.assert_total_order( a1, b2, -1, cls_a.meth, cls_b.meth) def test_str_subclass(self): """Compare instances of str and a subclass.""" class StrSubclass(str): pass s1 = str("a") s2 = str("b") c1 = StrSubclass("a") c2 = StrSubclass("b") c3 = StrSubclass("b") self.assert_total_order(s1, s1, 0) self.assert_total_order(s1, s2, -1) self.assert_total_order(c1, c1, 0) self.assert_total_order(c1, c2, -1) self.assert_total_order(c2, c3, 0) self.assert_total_order(s1, c2, -1) self.assert_total_order(s2, c3, 0) self.assert_total_order(c1, s2, -1) self.assert_total_order(c2, s2, 0) def test_numbers(self): """Compare number types.""" # Same types. i1 = 1001 i2 = 1002 self.assert_total_order(i1, i1, 0) self.assert_total_order(i1, i2, -1) f1 = 1001.0 f2 = 1001.1 self.assert_total_order(f1, f1, 0) self.assert_total_order(f1, f2, -1) q1 = Fraction(2002, 2) q2 = Fraction(2003, 2) self.assert_total_order(q1, q1, 0) self.assert_total_order(q1, q2, -1) d1 = Decimal('1001.0') d2 = Decimal('1001.1') self.assert_total_order(d1, d1, 0) self.assert_total_order(d1, d2, -1) c1 = 1001+0j c2 = 1001+1j self.assert_equality_only(c1, c1, True) self.assert_equality_only(c1, c2, False) # Mixing types. for n1, n2 in ((i1,f1), (i1,q1), (i1,d1), (f1,q1), (f1,d1), (q1,d1)): self.assert_total_order(n1, n2, 0) for n1 in (i1, f1, q1, d1): self.assert_equality_only(n1, c1, True) def test_sequences(self): """Compare list, tuple, and range.""" l1 = [1, 2] l2 = [2, 3] self.assert_total_order(l1, l1, 0) self.assert_total_order(l1, l2, -1) t1 = (1, 2) t2 = (2, 3) self.assert_total_order(t1, t1, 0) self.assert_total_order(t1, t2, -1) r1 = range(1, 2) r2 = range(2, 2) self.assert_equality_only(r1, r1, True) self.assert_equality_only(r1, r2, False) self.assert_equality_only(t1, l1, False) self.assert_equality_only(l1, r1, False) self.assert_equality_only(r1, t1, False) def test_bytes(self): """Compare bytes and bytearray.""" bs1 = b'a1' bs2 = b'b2' self.assert_total_order(bs1, bs1, 0) self.assert_total_order(bs1, bs2, -1) ba1 = bytearray(b'a1') ba2 = bytearray(b'b2') self.assert_total_order(ba1, ba1, 0) self.assert_total_order(ba1, ba2, -1) self.assert_total_order(bs1, ba1, 0) self.assert_total_order(bs1, ba2, -1) self.assert_total_order(ba1, bs1, 0) self.assert_total_order(ba1, bs2, -1) def test_sets(self): """Compare set and frozenset.""" s1 = {1, 2} s2 = {1, 2, 3} self.assert_total_order(s1, s1, 0) self.assert_total_order(s1, s2, -1) f1 = frozenset(s1) f2 = frozenset(s2) self.assert_total_order(f1, f1, 0) self.assert_total_order(f1, f2, -1) self.assert_total_order(s1, f1, 0) self.assert_total_order(s1, f2, -1) self.assert_total_order(f1, s1, 0) self.assert_total_order(f1, s2, -1) def test_mappings(self): """ Compare dict. """ d1 = {1: "a", 2: "b"} d2 = {2: "b", 3: "c"} d3 = {3: "c", 2: "b"} self.assert_equality_only(d1, d1, True) self.assert_equality_only(d1, d2, False) self.assert_equality_only(d2, d3, True) if __name__ == '__main__': unittest.main()