bpo-37685: Fixed __eq__, __lt__ etc implementations in some classes. (GH-14952)

They now return NotImplemented for unsupported type of the other operand.
This commit is contained in:
Serhiy Storchaka 2019-08-08 08:42:54 +03:00 committed by GitHub
parent 4c69be22df
commit 662db125cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 1295 additions and 1150 deletions

View File

@ -119,20 +119,24 @@ class TimerHandle(Handle):
return hash(self._when)
def __lt__(self, other):
if isinstance(other, TimerHandle):
return self._when < other._when
return NotImplemented
def __le__(self, other):
if self._when < other._when:
return True
return self.__eq__(other)
if isinstance(other, TimerHandle):
return self._when < other._when or self.__eq__(other)
return NotImplemented
def __gt__(self, other):
if isinstance(other, TimerHandle):
return self._when > other._when
return NotImplemented
def __ge__(self, other):
if self._when > other._when:
return True
return self.__eq__(other)
if isinstance(other, TimerHandle):
return self._when > other._when or self.__eq__(other)
return NotImplemented
def __eq__(self, other):
if isinstance(other, TimerHandle):
@ -142,10 +146,6 @@ class TimerHandle(Handle):
self._cancelled == other._cancelled)
return NotImplemented
def __ne__(self, other):
equal = self.__eq__(other)
return NotImplemented if equal is NotImplemented else not equal
def cancel(self):
if not self._cancelled:
self._loop._timer_handle_cancelled(self)

View File

@ -45,6 +45,14 @@ class VersionTestCase(unittest.TestCase):
self.assertEqual(res, wanted,
'cmp(%s, %s) should be %s, got %s' %
(v1, v2, wanted, res))
res = StrictVersion(v1)._cmp(v2)
self.assertEqual(res, wanted,
'cmp(%s, %s) should be %s, got %s' %
(v1, v2, wanted, res))
res = StrictVersion(v1)._cmp(object())
self.assertIs(res, NotImplemented,
'cmp(%s, %s) should be NotImplemented, got %s' %
(v1, v2, res))
def test_cmp(self):
@ -63,6 +71,14 @@ class VersionTestCase(unittest.TestCase):
self.assertEqual(res, wanted,
'cmp(%s, %s) should be %s, got %s' %
(v1, v2, wanted, res))
res = LooseVersion(v1)._cmp(v2)
self.assertEqual(res, wanted,
'cmp(%s, %s) should be %s, got %s' %
(v1, v2, wanted, res))
res = LooseVersion(v1)._cmp(object())
self.assertIs(res, NotImplemented,
'cmp(%s, %s) should be NotImplemented, got %s' %
(v1, v2, res))
def test_suite():
return unittest.makeSuite(VersionTestCase)

View File

@ -166,6 +166,8 @@ class StrictVersion (Version):
def _cmp (self, other):
if isinstance(other, str):
other = StrictVersion(other)
elif not isinstance(other, StrictVersion):
return NotImplemented
if self.version != other.version:
# numeric versions don't match
@ -331,6 +333,8 @@ class LooseVersion (Version):
def _cmp (self, other):
if isinstance(other, str):
other = LooseVersion(other)
elif not isinstance(other, LooseVersion):
return NotImplemented
if self.version == other.version:
return 0

View File

@ -97,8 +97,8 @@ class Address:
return self.addr_spec
def __eq__(self, other):
if type(other) != type(self):
return False
if not isinstance(other, Address):
return NotImplemented
return (self.display_name == other.display_name and
self.username == other.username and
self.domain == other.domain)
@ -150,8 +150,8 @@ class Group:
return "{}:{};".format(disp, adrstr)
def __eq__(self, other):
if type(other) != type(self):
return False
if not isinstance(other, Group):
return NotImplemented
return (self.display_name == other.display_name and
self.addresses == other.addresses)

View File

@ -371,7 +371,7 @@ class ModuleSpec:
self.cached == other.cached and
self.has_location == other.has_location)
except AttributeError:
return False
return NotImplemented
@property
def cached(self):

View File

@ -32,6 +32,7 @@ from asyncio import proactor_events
from asyncio import selector_events
from test.test_asyncio import utils as test_utils
from test import support
from test.support import ALWAYS_EQ, LARGEST, SMALLEST
def tearDownModule():
@ -2364,6 +2365,28 @@ class TimerTests(unittest.TestCase):
self.assertIs(NotImplemented, h1.__eq__(h3))
self.assertIs(NotImplemented, h1.__ne__(h3))
with self.assertRaises(TypeError):
h1 < ()
with self.assertRaises(TypeError):
h1 > ()
with self.assertRaises(TypeError):
h1 <= ()
with self.assertRaises(TypeError):
h1 >= ()
self.assertFalse(h1 == ())
self.assertTrue(h1 != ())
self.assertTrue(h1 == ALWAYS_EQ)
self.assertFalse(h1 != ALWAYS_EQ)
self.assertTrue(h1 < LARGEST)
self.assertFalse(h1 > LARGEST)
self.assertTrue(h1 <= LARGEST)
self.assertFalse(h1 >= LARGEST)
self.assertFalse(h1 < SMALLEST)
self.assertTrue(h1 > SMALLEST)
self.assertFalse(h1 <= SMALLEST)
self.assertTrue(h1 >= SMALLEST)
class AbstractEventLoopTests(unittest.TestCase):

View File

@ -7,6 +7,7 @@ from email.message import Message
from test.test_email import TestEmailBase, parameterize
from email import headerregistry
from email.headerregistry import Address, Group
from test.support import ALWAYS_EQ
DITTO = object()
@ -1525,6 +1526,24 @@ class TestAddressAndGroup(TestEmailBase):
self.assertEqual(m['to'], 'foo bar:;')
self.assertEqual(m['to'].addresses, g.addresses)
def test_address_comparison(self):
a = Address('foo', 'bar', 'example.com')
self.assertEqual(Address('foo', 'bar', 'example.com'), a)
self.assertNotEqual(Address('baz', 'bar', 'example.com'), a)
self.assertNotEqual(Address('foo', 'baz', 'example.com'), a)
self.assertNotEqual(Address('foo', 'bar', 'baz'), a)
self.assertFalse(a == object())
self.assertTrue(a == ALWAYS_EQ)
def test_group_comparison(self):
a = Address('foo', 'bar', 'example.com')
g = Group('foo bar', [a])
self.assertEqual(Group('foo bar', (a,)), g)
self.assertNotEqual(Group('baz', [a]), g)
self.assertNotEqual(Group('foo bar', []), g)
self.assertFalse(g == object())
self.assertTrue(g == ALWAYS_EQ)
class TestFolding(TestHeaderBase):

View File

@ -7,7 +7,7 @@ import sys
import unittest
import re
from test import support
from test.support import TESTFN, Error, captured_output, unlink, cpython_only
from test.support import TESTFN, Error, captured_output, unlink, cpython_only, ALWAYS_EQ
from test.support.script_helper import assert_python_ok
import textwrap
@ -887,6 +887,8 @@ class TestFrame(unittest.TestCase):
# operator fallbacks to FrameSummary.__eq__.
self.assertEqual(tuple(f), f)
self.assertIsNone(f.locals)
self.assertNotEqual(f, object())
self.assertEqual(f, ALWAYS_EQ)
def test_lazy_lines(self):
linecache.clearcache()
@ -1083,6 +1085,18 @@ class TestTracebackException(unittest.TestCase):
self.assertEqual(exc_info[0], exc.exc_type)
self.assertEqual(str(exc_info[1]), str(exc))
def test_comparison(self):
try:
1/0
except Exception:
exc_info = sys.exc_info()
exc = traceback.TracebackException(*exc_info)
exc2 = traceback.TracebackException(*exc_info)
self.assertIsNot(exc, exc2)
self.assertEqual(exc, exc2)
self.assertNotEqual(exc, object())
self.assertEqual(exc, ALWAYS_EQ)
def test_unhashable(self):
class UnhashableException(Exception):
def __eq__(self, other):

View File

@ -11,7 +11,7 @@ import time
import random
from test import support
from test.support import script_helper
from test.support import script_helper, ALWAYS_EQ
# Used in ReferencesTestCase.test_ref_created_during_del() .
ref_from_del = None
@ -794,6 +794,10 @@ class ReferencesTestCase(TestBase):
self.assertTrue(a != c)
self.assertTrue(a == d)
self.assertFalse(a != d)
self.assertFalse(a == x)
self.assertTrue(a != x)
self.assertTrue(a == ALWAYS_EQ)
self.assertFalse(a != ALWAYS_EQ)
del x, y, z
gc.collect()
for r in a, b, c:
@ -1102,6 +1106,9 @@ class WeakMethodTestCase(unittest.TestCase):
_ne(a, f)
_ne(b, e)
_ne(b, f)
# Compare with different types
_ne(a, x.some_method)
_eq(a, ALWAYS_EQ)
del x, y, z
gc.collect()
# Dead WeakMethods compare by identity

View File

@ -15,6 +15,7 @@ import re
import io
import contextlib
from test import support
from test.support import ALWAYS_EQ, LARGEST, SMALLEST
try:
import gzip
@ -530,14 +531,10 @@ class DateTimeTestCase(unittest.TestCase):
# some other types
dbytes = dstr.encode('ascii')
dtuple = now.timetuple()
with self.assertRaises(TypeError):
dtime == 1970
with self.assertRaises(TypeError):
dtime != dbytes
with self.assertRaises(TypeError):
dtime == bytearray(dbytes)
with self.assertRaises(TypeError):
dtime != dtuple
self.assertFalse(dtime == 1970)
self.assertTrue(dtime != dbytes)
self.assertFalse(dtime == bytearray(dbytes))
self.assertTrue(dtime != dtuple)
with self.assertRaises(TypeError):
dtime < float(1970)
with self.assertRaises(TypeError):
@ -547,6 +544,18 @@ class DateTimeTestCase(unittest.TestCase):
with self.assertRaises(TypeError):
dtime >= dtuple
self.assertTrue(dtime == ALWAYS_EQ)
self.assertFalse(dtime != ALWAYS_EQ)
self.assertTrue(dtime < LARGEST)
self.assertFalse(dtime > LARGEST)
self.assertTrue(dtime <= LARGEST)
self.assertFalse(dtime >= LARGEST)
self.assertFalse(dtime < SMALLEST)
self.assertTrue(dtime > SMALLEST)
self.assertFalse(dtime <= SMALLEST)
self.assertTrue(dtime >= SMALLEST)
class BinaryTestCase(unittest.TestCase):
# XXX What should str(Binary(b"\xff")) return? I'm chosing "\xff"

View File

@ -484,6 +484,8 @@ class Variable:
Note: if the Variable's master matters to behavior
also compare self._master == other._master
"""
if not isinstance(other, Variable):
return NotImplemented
return self.__class__.__name__ == other.__class__.__name__ \
and self._name == other._name

View File

@ -101,7 +101,9 @@ class Font:
return self.name
def __eq__(self, other):
return isinstance(other, Font) and self.name == other.name
if not isinstance(other, Font):
return NotImplemented
return self.name == other.name
def __getitem__(self, key):
return self.cget(key)

View File

@ -1,7 +1,7 @@
import unittest
import tkinter
from tkinter import font
from test.support import requires, run_unittest, gc_collect
from test.support import requires, run_unittest, gc_collect, ALWAYS_EQ
from tkinter.test.support import AbstractTkTest
requires('gui')
@ -70,6 +70,7 @@ class FontTest(AbstractTkTest, unittest.TestCase):
self.assertEqual(font1, font2)
self.assertNotEqual(font1, font1.copy())
self.assertNotEqual(font1, 0)
self.assertEqual(font1, ALWAYS_EQ)
def test_measure(self):
self.assertIsInstance(self.font.measure('abc'), int)

View File

@ -2,6 +2,7 @@ import unittest
import gc
from tkinter import (Variable, StringVar, IntVar, DoubleVar, BooleanVar, Tcl,
TclError)
from test.support import ALWAYS_EQ
class Var(Variable):
@ -59,11 +60,17 @@ class TestVariable(TestBase):
# values doesn't matter, only class and name are checked
v1 = Variable(self.root, name="abc")
v2 = Variable(self.root, name="abc")
self.assertIsNot(v1, v2)
self.assertEqual(v1, v2)
v3 = Variable(self.root, name="abc")
v4 = StringVar(self.root, name="abc")
self.assertNotEqual(v3, v4)
v3 = StringVar(self.root, name="abc")
self.assertNotEqual(v1, v3)
V = type('Variable', (), {})
self.assertNotEqual(v1, V())
self.assertNotEqual(v1, object())
self.assertEqual(v1, ALWAYS_EQ)
def test_invalid_name(self):
with self.assertRaises(TypeError):

View File

@ -538,7 +538,9 @@ class TracebackException:
self.__cause__._load_lines()
def __eq__(self, other):
if isinstance(other, TracebackException):
return self.__dict__ == other.__dict__
return NotImplemented
def __str__(self):
return self._str

View File

@ -43,6 +43,8 @@ class Statistic:
return hash((self.traceback, self.size, self.count))
def __eq__(self, other):
if not isinstance(other, Statistic):
return NotImplemented
return (self.traceback == other.traceback
and self.size == other.size
and self.count == other.count)
@ -84,6 +86,8 @@ class StatisticDiff:
self.count, self.count_diff))
def __eq__(self, other):
if not isinstance(other, StatisticDiff):
return NotImplemented
return (self.traceback == other.traceback
and self.size == other.size
and self.size_diff == other.size_diff
@ -153,9 +157,13 @@ class Frame:
return self._frame[1]
def __eq__(self, other):
if not isinstance(other, Frame):
return NotImplemented
return (self._frame == other._frame)
def __lt__(self, other):
if not isinstance(other, Frame):
return NotImplemented
return (self._frame < other._frame)
def __hash__(self):
@ -200,9 +208,13 @@ class Traceback(Sequence):
return hash(self._frames)
def __eq__(self, other):
if not isinstance(other, Traceback):
return NotImplemented
return (self._frames == other._frames)
def __lt__(self, other):
if not isinstance(other, Traceback):
return NotImplemented
return (self._frames < other._frames)
def __str__(self):
@ -271,6 +283,8 @@ class Trace:
return Traceback(self._trace[2])
def __eq__(self, other):
if not isinstance(other, Trace):
return NotImplemented
return (self._trace == other._trace)
def __hash__(self):
@ -303,6 +317,8 @@ class _Traces(Sequence):
return trace._trace in self._traces
def __eq__(self, other):
if not isinstance(other, _Traces):
return NotImplemented
return (self._traces == other._traces)
def __repr__(self):

View File

@ -2358,12 +2358,10 @@ class _Call(tuple):
def __eq__(self, other):
if other is ANY:
return True
try:
len_other = len(other)
except TypeError:
return False
return NotImplemented
self_name = ''
if len(self) == 2:

View File

@ -3,6 +3,7 @@ import re
import sys
import tempfile
from test.support import ALWAYS_EQ
import unittest
from unittest.test.testmock.support import is_instance
from unittest import mock
@ -322,6 +323,8 @@ class MockTest(unittest.TestCase):
self.assertFalse(mm != mock.ANY)
self.assertTrue(mock.ANY == mm)
self.assertFalse(mock.ANY != mm)
self.assertTrue(mm == ALWAYS_EQ)
self.assertFalse(mm != ALWAYS_EQ)
call1 = mock.call(mock.MagicMock())
call2 = mock.call(mock.ANY)
@ -330,6 +333,11 @@ class MockTest(unittest.TestCase):
self.assertTrue(call2 == call1)
self.assertFalse(call2 != call1)
self.assertTrue(call1 == ALWAYS_EQ)
self.assertFalse(call1 != ALWAYS_EQ)
self.assertFalse(call1 == 1)
self.assertTrue(call1 != 1)
def test_assert_called_with(self):
mock = Mock()

View File

@ -75,14 +75,14 @@ class WeakMethod(ref):
if not self._alive or not other._alive:
return self is other
return ref.__eq__(self, other) and self._func_ref == other._func_ref
return False
return NotImplemented
def __ne__(self, other):
if isinstance(other, WeakMethod):
if not self._alive or not other._alive:
return self is not other
return ref.__ne__(self, other) or self._func_ref != other._func_ref
return True
return NotImplemented
__hash__ = ref.__hash__

View File

@ -313,31 +313,38 @@ class DateTime:
s = self.timetuple()
o = other.timetuple()
else:
otype = (hasattr(other, "__class__")
and other.__class__.__name__
or type(other))
raise TypeError("Can't compare %s and %s" %
(self.__class__.__name__, otype))
s = self
o = NotImplemented
return s, o
def __lt__(self, other):
s, o = self.make_comparable(other)
if o is NotImplemented:
return NotImplemented
return s < o
def __le__(self, other):
s, o = self.make_comparable(other)
if o is NotImplemented:
return NotImplemented
return s <= o
def __gt__(self, other):
s, o = self.make_comparable(other)
if o is NotImplemented:
return NotImplemented
return s > o
def __ge__(self, other):
s, o = self.make_comparable(other)
if o is NotImplemented:
return NotImplemented
return s >= o
def __eq__(self, other):
s, o = self.make_comparable(other)
if o is NotImplemented:
return NotImplemented
return s == o
def timetuple(self):

View File

@ -0,0 +1,4 @@
Fixed ``__eq__``, ``__lt__`` etc implementations in some classes. They now
return :data:`NotImplemented` for unsupported type of the other operand.
This allows the other operand to play role (for example the equality
comparison with :data:`~unittest.mock.ANY` will return ``True``).

2212
Python/importlib.h generated

File diff suppressed because it is too large Load Diff

View File

@ -281,10 +281,14 @@ class PopupViewer:
self.__window.deiconify()
def __eq__(self, other):
if isinstance(self, PopupViewer):
return self.__menutext == other.__menutext
return NotImplemented
def __lt__(self, other):
if isinstance(self, PopupViewer):
return self.__menutext < other.__menutext
return NotImplemented
def make_view_popups(switchboard, root, extrapath):