gh-56276: Add tests to test_compare (#3199)

Co-authored-by: Terry Jan Reedy <tjreedy@udel.edu>
Co-authored-by: Oleg Iarygin <oleg@arhadthedev.net>
This commit is contained in:
Cheryl Sabella 2023-05-20 12:07:40 -04:00 committed by GitHub
parent 2c97878bb8
commit 68ee8b3f15
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 409 additions and 17 deletions

View File

@ -1,21 +1,27 @@
"""Test equality and order comparisons."""
import unittest
from test.support import ALWAYS_EQ
from fractions import Fraction
from decimal import Decimal
class Empty:
def __repr__(self):
return '<Empty>'
class Cmp:
def __init__(self,arg):
self.arg = arg
class ComparisonSimpleTest(unittest.TestCase):
"""Test equality and order comparisons for some simple cases."""
def __repr__(self):
return '<Cmp %s>' % self.arg
class Empty:
def __repr__(self):
return '<Empty>'
def __eq__(self, other):
return self.arg == other
class Cmp:
def __init__(self, arg):
self.arg = arg
def __repr__(self):
return '<Cmp %s>' % self.arg
def __eq__(self, other):
return self.arg == other
class ComparisonTest(unittest.TestCase):
set1 = [2, 2.0, 2, 2+0j, Cmp(2.0)]
set2 = [[1], (3,), None, Empty()]
candidates = set1 + set2
@ -32,16 +38,15 @@ class ComparisonTest(unittest.TestCase):
# Ensure default comparison compares id() of args
L = []
for i in range(10):
L.insert(len(L)//2, Empty())
L.insert(len(L)//2, self.Empty())
for a in L:
for b in L:
self.assertEqual(a == b, id(a) == id(b),
'a=%r, b=%r' % (a, b))
self.assertEqual(a == b, a is b, 'a=%r, b=%r' % (a, b))
def test_ne_defaults_to_not_eq(self):
a = Cmp(1)
b = Cmp(1)
c = Cmp(2)
a = self.Cmp(1)
b = self.Cmp(1)
c = self.Cmp(2)
self.assertIs(a == b, True)
self.assertIs(a != b, False)
self.assertIs(a != c, True)
@ -114,5 +119,392 @@ class ComparisonTest(unittest.TestCase):
self.assertEqual(ALWAYS_EQ, y)
class ComparisonFullTest(unittest.TestCase):
"""Test equality and ordering comparisons for built-in types and
user-defined classes that implement relevant combinations of rich
comparison methods.
"""
class CompBase:
"""Base class for classes with rich comparison methods.
The "x" attribute should be set to an underlying value to compare.
Derived classes have a "meth" tuple attribute listing names of
comparison methods implemented. See assert_total_order().
"""
# Class without any rich comparison methods.
class CompNone(CompBase):
meth = ()
# Classes with all combinations of value-based equality comparison methods.
class CompEq(CompBase):
meth = ("eq",)
def __eq__(self, other):
return self.x == other.x
class CompNe(CompBase):
meth = ("ne",)
def __ne__(self, other):
return self.x != other.x
class CompEqNe(CompBase):
meth = ("eq", "ne")
def __eq__(self, other):
return self.x == other.x
def __ne__(self, other):
return self.x != other.x
# Classes with all combinations of value-based less/greater-than order
# comparison methods.
class CompLt(CompBase):
meth = ("lt",)
def __lt__(self, other):
return self.x < other.x
class CompGt(CompBase):
meth = ("gt",)
def __gt__(self, other):
return self.x > other.x
class CompLtGt(CompBase):
meth = ("lt", "gt")
def __lt__(self, other):
return self.x < other.x
def __gt__(self, other):
return self.x > other.x
# Classes with all combinations of value-based less/greater-or-equal-than
# order comparison methods
class CompLe(CompBase):
meth = ("le",)
def __le__(self, other):
return self.x <= other.x
class CompGe(CompBase):
meth = ("ge",)
def __ge__(self, other):
return self.x >= other.x
class CompLeGe(CompBase):
meth = ("le", "ge")
def __le__(self, other):
return self.x <= other.x
def __ge__(self, other):
return self.x >= other.x
# It should be sufficient to combine the comparison methods only within
# each group.
all_comp_classes = (
CompNone,
CompEq, CompNe, CompEqNe, # equal group
CompLt, CompGt, CompLtGt, # less/greater-than group
CompLe, CompGe, CompLeGe) # less/greater-or-equal group
def create_sorted_instances(self, class_, values):
"""Create objects of type `class_` and return them in a list.
`values` is a list of values that determines the value of data
attribute `x` of each object.
Objects in the returned list are sorted by their identity. They
assigned values in `values` list order. By assign decreasing
values to objects with increasing identities, testcases can assert
that order comparison is performed by value and not by identity.
"""
instances = [class_() for __ in range(len(values))]
instances.sort(key=id)
# Assign the provided values to the instances.
for inst, value in zip(instances, values):
inst.x = value
return instances
def assert_equality_only(self, a, b, equal):
"""Assert equality result and that ordering is not implemented.
a, b: Instances to be tested (of same or different type).
equal: Boolean indicating the expected equality comparison results.
"""
self.assertEqual(a == b, equal)
self.assertEqual(b == a, equal)
self.assertEqual(a != b, not equal)
self.assertEqual(b != a, not equal)
with self.assertRaisesRegex(TypeError, "not supported"):
a < b
with self.assertRaisesRegex(TypeError, "not supported"):
a <= b
with self.assertRaisesRegex(TypeError, "not supported"):
a > b
with self.assertRaisesRegex(TypeError, "not supported"):
a >= b
with self.assertRaisesRegex(TypeError, "not supported"):
b < a
with self.assertRaisesRegex(TypeError, "not supported"):
b <= a
with self.assertRaisesRegex(TypeError, "not supported"):
b > a
with self.assertRaisesRegex(TypeError, "not supported"):
b >= a
def assert_total_order(self, a, b, comp, a_meth=None, b_meth=None):
"""Test total ordering comparison of two instances.
a, b: Instances to be tested (of same or different type).
comp: -1, 0, or 1 indicates that the expected order comparison
result for operations that are supported by the classes is
a <, ==, or > b.
a_meth, b_meth: Either None, indicating that all rich comparison
methods are available, aa for builtins, or the tuple (subset)
of "eq", "ne", "lt", "le", "gt", and "ge" that are available
for the corresponding instance (of a user-defined class).
"""
self.assert_eq_subtest(a, b, comp, a_meth, b_meth)
self.assert_ne_subtest(a, b, comp, a_meth, b_meth)
self.assert_lt_subtest(a, b, comp, a_meth, b_meth)
self.assert_le_subtest(a, b, comp, a_meth, b_meth)
self.assert_gt_subtest(a, b, comp, a_meth, b_meth)
self.assert_ge_subtest(a, b, comp, a_meth, b_meth)
# The body of each subtest has form:
#
# if value-based comparison methods:
# expect what the testcase defined for a op b and b rop a;
# else: no value-based comparison
# expect default behavior of object for a op b and b rop a.
def assert_eq_subtest(self, a, b, comp, a_meth, b_meth):
if a_meth is None or "eq" in a_meth or "eq" in b_meth:
self.assertEqual(a == b, comp == 0)
self.assertEqual(b == a, comp == 0)
else:
self.assertEqual(a == b, a is b)
self.assertEqual(b == a, a is b)
def assert_ne_subtest(self, a, b, comp, a_meth, b_meth):
if a_meth is None or not {"ne", "eq"}.isdisjoint(a_meth + b_meth):
self.assertEqual(a != b, comp != 0)
self.assertEqual(b != a, comp != 0)
else:
self.assertEqual(a != b, a is not b)
self.assertEqual(b != a, a is not b)
def assert_lt_subtest(self, a, b, comp, a_meth, b_meth):
if a_meth is None or "lt" in a_meth or "gt" in b_meth:
self.assertEqual(a < b, comp < 0)
self.assertEqual(b > a, comp < 0)
else:
with self.assertRaisesRegex(TypeError, "not supported"):
a < b
with self.assertRaisesRegex(TypeError, "not supported"):
b > a
def assert_le_subtest(self, a, b, comp, a_meth, b_meth):
if a_meth is None or "le" in a_meth or "ge" in b_meth:
self.assertEqual(a <= b, comp <= 0)
self.assertEqual(b >= a, comp <= 0)
else:
with self.assertRaisesRegex(TypeError, "not supported"):
a <= b
with self.assertRaisesRegex(TypeError, "not supported"):
b >= a
def assert_gt_subtest(self, a, b, comp, a_meth, b_meth):
if a_meth is None or "gt" in a_meth or "lt" in b_meth:
self.assertEqual(a > b, comp > 0)
self.assertEqual(b < a, comp > 0)
else:
with self.assertRaisesRegex(TypeError, "not supported"):
a > b
with self.assertRaisesRegex(TypeError, "not supported"):
b < a
def assert_ge_subtest(self, a, b, comp, a_meth, b_meth):
if a_meth is None or "ge" in a_meth or "le" in b_meth:
self.assertEqual(a >= b, comp >= 0)
self.assertEqual(b <= a, comp >= 0)
else:
with self.assertRaisesRegex(TypeError, "not supported"):
a >= b
with self.assertRaisesRegex(TypeError, "not supported"):
b <= a
def test_objects(self):
"""Compare instances of type 'object'."""
a = object()
b = object()
self.assert_equality_only(a, a, True)
self.assert_equality_only(a, b, False)
def test_comp_classes_same(self):
"""Compare same-class instances with comparison methods."""
for cls in self.all_comp_classes:
with self.subTest(cls):
instances = self.create_sorted_instances(cls, (1, 2, 1))
# Same object.
self.assert_total_order(instances[0], instances[0], 0,
cls.meth, cls.meth)
# Different objects, same value.
self.assert_total_order(instances[0], instances[2], 0,
cls.meth, cls.meth)
# Different objects, value ascending for ascending identities.
self.assert_total_order(instances[0], instances[1], -1,
cls.meth, cls.meth)
# different objects, value descending for ascending identities.
# This is the interesting case to assert that order comparison
# is performed based on the value and not based on the identity.
self.assert_total_order(instances[1], instances[2], +1,
cls.meth, cls.meth)
def test_comp_classes_different(self):
"""Compare different-class instances with comparison methods."""
for cls_a in self.all_comp_classes:
for cls_b in self.all_comp_classes:
with self.subTest(a=cls_a, b=cls_b):
a1 = cls_a()
a1.x = 1
b1 = cls_b()
b1.x = 1
b2 = cls_b()
b2.x = 2
self.assert_total_order(
a1, b1, 0, cls_a.meth, cls_b.meth)
self.assert_total_order(
a1, b2, -1, cls_a.meth, cls_b.meth)
def test_str_subclass(self):
"""Compare instances of str and a subclass."""
class StrSubclass(str):
pass
s1 = str("a")
s2 = str("b")
c1 = StrSubclass("a")
c2 = StrSubclass("b")
c3 = StrSubclass("b")
self.assert_total_order(s1, s1, 0)
self.assert_total_order(s1, s2, -1)
self.assert_total_order(c1, c1, 0)
self.assert_total_order(c1, c2, -1)
self.assert_total_order(c2, c3, 0)
self.assert_total_order(s1, c2, -1)
self.assert_total_order(s2, c3, 0)
self.assert_total_order(c1, s2, -1)
self.assert_total_order(c2, s2, 0)
def test_numbers(self):
"""Compare number types."""
# Same types.
i1 = 1001
i2 = 1002
self.assert_total_order(i1, i1, 0)
self.assert_total_order(i1, i2, -1)
f1 = 1001.0
f2 = 1001.1
self.assert_total_order(f1, f1, 0)
self.assert_total_order(f1, f2, -1)
q1 = Fraction(2002, 2)
q2 = Fraction(2003, 2)
self.assert_total_order(q1, q1, 0)
self.assert_total_order(q1, q2, -1)
d1 = Decimal('1001.0')
d2 = Decimal('1001.1')
self.assert_total_order(d1, d1, 0)
self.assert_total_order(d1, d2, -1)
c1 = 1001+0j
c2 = 1001+1j
self.assert_equality_only(c1, c1, True)
self.assert_equality_only(c1, c2, False)
# Mixing types.
for n1, n2 in ((i1,f1), (i1,q1), (i1,d1), (f1,q1), (f1,d1), (q1,d1)):
self.assert_total_order(n1, n2, 0)
for n1 in (i1, f1, q1, d1):
self.assert_equality_only(n1, c1, True)
def test_sequences(self):
"""Compare list, tuple, and range."""
l1 = [1, 2]
l2 = [2, 3]
self.assert_total_order(l1, l1, 0)
self.assert_total_order(l1, l2, -1)
t1 = (1, 2)
t2 = (2, 3)
self.assert_total_order(t1, t1, 0)
self.assert_total_order(t1, t2, -1)
r1 = range(1, 2)
r2 = range(2, 2)
self.assert_equality_only(r1, r1, True)
self.assert_equality_only(r1, r2, False)
self.assert_equality_only(t1, l1, False)
self.assert_equality_only(l1, r1, False)
self.assert_equality_only(r1, t1, False)
def test_bytes(self):
"""Compare bytes and bytearray."""
bs1 = b'a1'
bs2 = b'b2'
self.assert_total_order(bs1, bs1, 0)
self.assert_total_order(bs1, bs2, -1)
ba1 = bytearray(b'a1')
ba2 = bytearray(b'b2')
self.assert_total_order(ba1, ba1, 0)
self.assert_total_order(ba1, ba2, -1)
self.assert_total_order(bs1, ba1, 0)
self.assert_total_order(bs1, ba2, -1)
self.assert_total_order(ba1, bs1, 0)
self.assert_total_order(ba1, bs2, -1)
def test_sets(self):
"""Compare set and frozenset."""
s1 = {1, 2}
s2 = {1, 2, 3}
self.assert_total_order(s1, s1, 0)
self.assert_total_order(s1, s2, -1)
f1 = frozenset(s1)
f2 = frozenset(s2)
self.assert_total_order(f1, f1, 0)
self.assert_total_order(f1, f2, -1)
self.assert_total_order(s1, f1, 0)
self.assert_total_order(s1, f2, -1)
self.assert_total_order(f1, s1, 0)
self.assert_total_order(f1, s2, -1)
def test_mappings(self):
""" Compare dict.
"""
d1 = {1: "a", 2: "b"}
d2 = {2: "b", 3: "c"}
d3 = {3: "c", 2: "b"}
self.assert_equality_only(d1, d1, True)
self.assert_equality_only(d1, d2, False)
self.assert_equality_only(d2, d3, True)
if __name__ == '__main__':
unittest.main()