bpo-37685: Fixed __eq__, __lt__ etc implementations in some classes. (GH-14952)
They now return NotImplemented for unsupported type of the other operand.
This commit is contained in:
parent
4c69be22df
commit
662db125cd
|
@ -119,20 +119,24 @@ class TimerHandle(Handle):
|
||||||
return hash(self._when)
|
return hash(self._when)
|
||||||
|
|
||||||
def __lt__(self, other):
|
def __lt__(self, other):
|
||||||
return self._when < other._when
|
if isinstance(other, TimerHandle):
|
||||||
|
return self._when < other._when
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
def __le__(self, other):
|
def __le__(self, other):
|
||||||
if self._when < other._when:
|
if isinstance(other, TimerHandle):
|
||||||
return True
|
return self._when < other._when or self.__eq__(other)
|
||||||
return self.__eq__(other)
|
return NotImplemented
|
||||||
|
|
||||||
def __gt__(self, other):
|
def __gt__(self, other):
|
||||||
return self._when > other._when
|
if isinstance(other, TimerHandle):
|
||||||
|
return self._when > other._when
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
def __ge__(self, other):
|
def __ge__(self, other):
|
||||||
if self._when > other._when:
|
if isinstance(other, TimerHandle):
|
||||||
return True
|
return self._when > other._when or self.__eq__(other)
|
||||||
return self.__eq__(other)
|
return NotImplemented
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if isinstance(other, TimerHandle):
|
if isinstance(other, TimerHandle):
|
||||||
|
@ -142,10 +146,6 @@ class TimerHandle(Handle):
|
||||||
self._cancelled == other._cancelled)
|
self._cancelled == other._cancelled)
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
||||||
def __ne__(self, other):
|
|
||||||
equal = self.__eq__(other)
|
|
||||||
return NotImplemented if equal is NotImplemented else not equal
|
|
||||||
|
|
||||||
def cancel(self):
|
def cancel(self):
|
||||||
if not self._cancelled:
|
if not self._cancelled:
|
||||||
self._loop._timer_handle_cancelled(self)
|
self._loop._timer_handle_cancelled(self)
|
||||||
|
|
|
@ -45,6 +45,14 @@ class VersionTestCase(unittest.TestCase):
|
||||||
self.assertEqual(res, wanted,
|
self.assertEqual(res, wanted,
|
||||||
'cmp(%s, %s) should be %s, got %s' %
|
'cmp(%s, %s) should be %s, got %s' %
|
||||||
(v1, v2, wanted, res))
|
(v1, v2, wanted, res))
|
||||||
|
res = StrictVersion(v1)._cmp(v2)
|
||||||
|
self.assertEqual(res, wanted,
|
||||||
|
'cmp(%s, %s) should be %s, got %s' %
|
||||||
|
(v1, v2, wanted, res))
|
||||||
|
res = StrictVersion(v1)._cmp(object())
|
||||||
|
self.assertIs(res, NotImplemented,
|
||||||
|
'cmp(%s, %s) should be NotImplemented, got %s' %
|
||||||
|
(v1, v2, res))
|
||||||
|
|
||||||
|
|
||||||
def test_cmp(self):
|
def test_cmp(self):
|
||||||
|
@ -63,6 +71,14 @@ class VersionTestCase(unittest.TestCase):
|
||||||
self.assertEqual(res, wanted,
|
self.assertEqual(res, wanted,
|
||||||
'cmp(%s, %s) should be %s, got %s' %
|
'cmp(%s, %s) should be %s, got %s' %
|
||||||
(v1, v2, wanted, res))
|
(v1, v2, wanted, res))
|
||||||
|
res = LooseVersion(v1)._cmp(v2)
|
||||||
|
self.assertEqual(res, wanted,
|
||||||
|
'cmp(%s, %s) should be %s, got %s' %
|
||||||
|
(v1, v2, wanted, res))
|
||||||
|
res = LooseVersion(v1)._cmp(object())
|
||||||
|
self.assertIs(res, NotImplemented,
|
||||||
|
'cmp(%s, %s) should be NotImplemented, got %s' %
|
||||||
|
(v1, v2, res))
|
||||||
|
|
||||||
def test_suite():
|
def test_suite():
|
||||||
return unittest.makeSuite(VersionTestCase)
|
return unittest.makeSuite(VersionTestCase)
|
||||||
|
|
|
@ -166,6 +166,8 @@ class StrictVersion (Version):
|
||||||
def _cmp (self, other):
|
def _cmp (self, other):
|
||||||
if isinstance(other, str):
|
if isinstance(other, str):
|
||||||
other = StrictVersion(other)
|
other = StrictVersion(other)
|
||||||
|
elif not isinstance(other, StrictVersion):
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
if self.version != other.version:
|
if self.version != other.version:
|
||||||
# numeric versions don't match
|
# numeric versions don't match
|
||||||
|
@ -331,6 +333,8 @@ class LooseVersion (Version):
|
||||||
def _cmp (self, other):
|
def _cmp (self, other):
|
||||||
if isinstance(other, str):
|
if isinstance(other, str):
|
||||||
other = LooseVersion(other)
|
other = LooseVersion(other)
|
||||||
|
elif not isinstance(other, LooseVersion):
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
if self.version == other.version:
|
if self.version == other.version:
|
||||||
return 0
|
return 0
|
||||||
|
|
|
@ -97,8 +97,8 @@ class Address:
|
||||||
return self.addr_spec
|
return self.addr_spec
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if type(other) != type(self):
|
if not isinstance(other, Address):
|
||||||
return False
|
return NotImplemented
|
||||||
return (self.display_name == other.display_name and
|
return (self.display_name == other.display_name and
|
||||||
self.username == other.username and
|
self.username == other.username and
|
||||||
self.domain == other.domain)
|
self.domain == other.domain)
|
||||||
|
@ -150,8 +150,8 @@ class Group:
|
||||||
return "{}:{};".format(disp, adrstr)
|
return "{}:{};".format(disp, adrstr)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if type(other) != type(self):
|
if not isinstance(other, Group):
|
||||||
return False
|
return NotImplemented
|
||||||
return (self.display_name == other.display_name and
|
return (self.display_name == other.display_name and
|
||||||
self.addresses == other.addresses)
|
self.addresses == other.addresses)
|
||||||
|
|
||||||
|
|
|
@ -371,7 +371,7 @@ class ModuleSpec:
|
||||||
self.cached == other.cached and
|
self.cached == other.cached and
|
||||||
self.has_location == other.has_location)
|
self.has_location == other.has_location)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
return False
|
return NotImplemented
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cached(self):
|
def cached(self):
|
||||||
|
|
|
@ -32,6 +32,7 @@ from asyncio import proactor_events
|
||||||
from asyncio import selector_events
|
from asyncio import selector_events
|
||||||
from test.test_asyncio import utils as test_utils
|
from test.test_asyncio import utils as test_utils
|
||||||
from test import support
|
from test import support
|
||||||
|
from test.support import ALWAYS_EQ, LARGEST, SMALLEST
|
||||||
|
|
||||||
|
|
||||||
def tearDownModule():
|
def tearDownModule():
|
||||||
|
@ -2364,6 +2365,28 @@ class TimerTests(unittest.TestCase):
|
||||||
self.assertIs(NotImplemented, h1.__eq__(h3))
|
self.assertIs(NotImplemented, h1.__eq__(h3))
|
||||||
self.assertIs(NotImplemented, h1.__ne__(h3))
|
self.assertIs(NotImplemented, h1.__ne__(h3))
|
||||||
|
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
h1 < ()
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
h1 > ()
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
h1 <= ()
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
h1 >= ()
|
||||||
|
self.assertFalse(h1 == ())
|
||||||
|
self.assertTrue(h1 != ())
|
||||||
|
|
||||||
|
self.assertTrue(h1 == ALWAYS_EQ)
|
||||||
|
self.assertFalse(h1 != ALWAYS_EQ)
|
||||||
|
self.assertTrue(h1 < LARGEST)
|
||||||
|
self.assertFalse(h1 > LARGEST)
|
||||||
|
self.assertTrue(h1 <= LARGEST)
|
||||||
|
self.assertFalse(h1 >= LARGEST)
|
||||||
|
self.assertFalse(h1 < SMALLEST)
|
||||||
|
self.assertTrue(h1 > SMALLEST)
|
||||||
|
self.assertFalse(h1 <= SMALLEST)
|
||||||
|
self.assertTrue(h1 >= SMALLEST)
|
||||||
|
|
||||||
|
|
||||||
class AbstractEventLoopTests(unittest.TestCase):
|
class AbstractEventLoopTests(unittest.TestCase):
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ from email.message import Message
|
||||||
from test.test_email import TestEmailBase, parameterize
|
from test.test_email import TestEmailBase, parameterize
|
||||||
from email import headerregistry
|
from email import headerregistry
|
||||||
from email.headerregistry import Address, Group
|
from email.headerregistry import Address, Group
|
||||||
|
from test.support import ALWAYS_EQ
|
||||||
|
|
||||||
|
|
||||||
DITTO = object()
|
DITTO = object()
|
||||||
|
@ -1525,6 +1526,24 @@ class TestAddressAndGroup(TestEmailBase):
|
||||||
self.assertEqual(m['to'], 'foo bar:;')
|
self.assertEqual(m['to'], 'foo bar:;')
|
||||||
self.assertEqual(m['to'].addresses, g.addresses)
|
self.assertEqual(m['to'].addresses, g.addresses)
|
||||||
|
|
||||||
|
def test_address_comparison(self):
|
||||||
|
a = Address('foo', 'bar', 'example.com')
|
||||||
|
self.assertEqual(Address('foo', 'bar', 'example.com'), a)
|
||||||
|
self.assertNotEqual(Address('baz', 'bar', 'example.com'), a)
|
||||||
|
self.assertNotEqual(Address('foo', 'baz', 'example.com'), a)
|
||||||
|
self.assertNotEqual(Address('foo', 'bar', 'baz'), a)
|
||||||
|
self.assertFalse(a == object())
|
||||||
|
self.assertTrue(a == ALWAYS_EQ)
|
||||||
|
|
||||||
|
def test_group_comparison(self):
|
||||||
|
a = Address('foo', 'bar', 'example.com')
|
||||||
|
g = Group('foo bar', [a])
|
||||||
|
self.assertEqual(Group('foo bar', (a,)), g)
|
||||||
|
self.assertNotEqual(Group('baz', [a]), g)
|
||||||
|
self.assertNotEqual(Group('foo bar', []), g)
|
||||||
|
self.assertFalse(g == object())
|
||||||
|
self.assertTrue(g == ALWAYS_EQ)
|
||||||
|
|
||||||
|
|
||||||
class TestFolding(TestHeaderBase):
|
class TestFolding(TestHeaderBase):
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,7 @@ import sys
|
||||||
import unittest
|
import unittest
|
||||||
import re
|
import re
|
||||||
from test import support
|
from test import support
|
||||||
from test.support import TESTFN, Error, captured_output, unlink, cpython_only
|
from test.support import TESTFN, Error, captured_output, unlink, cpython_only, ALWAYS_EQ
|
||||||
from test.support.script_helper import assert_python_ok
|
from test.support.script_helper import assert_python_ok
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
|
@ -887,6 +887,8 @@ class TestFrame(unittest.TestCase):
|
||||||
# operator fallbacks to FrameSummary.__eq__.
|
# operator fallbacks to FrameSummary.__eq__.
|
||||||
self.assertEqual(tuple(f), f)
|
self.assertEqual(tuple(f), f)
|
||||||
self.assertIsNone(f.locals)
|
self.assertIsNone(f.locals)
|
||||||
|
self.assertNotEqual(f, object())
|
||||||
|
self.assertEqual(f, ALWAYS_EQ)
|
||||||
|
|
||||||
def test_lazy_lines(self):
|
def test_lazy_lines(self):
|
||||||
linecache.clearcache()
|
linecache.clearcache()
|
||||||
|
@ -1083,6 +1085,18 @@ class TestTracebackException(unittest.TestCase):
|
||||||
self.assertEqual(exc_info[0], exc.exc_type)
|
self.assertEqual(exc_info[0], exc.exc_type)
|
||||||
self.assertEqual(str(exc_info[1]), str(exc))
|
self.assertEqual(str(exc_info[1]), str(exc))
|
||||||
|
|
||||||
|
def test_comparison(self):
|
||||||
|
try:
|
||||||
|
1/0
|
||||||
|
except Exception:
|
||||||
|
exc_info = sys.exc_info()
|
||||||
|
exc = traceback.TracebackException(*exc_info)
|
||||||
|
exc2 = traceback.TracebackException(*exc_info)
|
||||||
|
self.assertIsNot(exc, exc2)
|
||||||
|
self.assertEqual(exc, exc2)
|
||||||
|
self.assertNotEqual(exc, object())
|
||||||
|
self.assertEqual(exc, ALWAYS_EQ)
|
||||||
|
|
||||||
def test_unhashable(self):
|
def test_unhashable(self):
|
||||||
class UnhashableException(Exception):
|
class UnhashableException(Exception):
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
|
|
|
@ -11,7 +11,7 @@ import time
|
||||||
import random
|
import random
|
||||||
|
|
||||||
from test import support
|
from test import support
|
||||||
from test.support import script_helper
|
from test.support import script_helper, ALWAYS_EQ
|
||||||
|
|
||||||
# Used in ReferencesTestCase.test_ref_created_during_del() .
|
# Used in ReferencesTestCase.test_ref_created_during_del() .
|
||||||
ref_from_del = None
|
ref_from_del = None
|
||||||
|
@ -794,6 +794,10 @@ class ReferencesTestCase(TestBase):
|
||||||
self.assertTrue(a != c)
|
self.assertTrue(a != c)
|
||||||
self.assertTrue(a == d)
|
self.assertTrue(a == d)
|
||||||
self.assertFalse(a != d)
|
self.assertFalse(a != d)
|
||||||
|
self.assertFalse(a == x)
|
||||||
|
self.assertTrue(a != x)
|
||||||
|
self.assertTrue(a == ALWAYS_EQ)
|
||||||
|
self.assertFalse(a != ALWAYS_EQ)
|
||||||
del x, y, z
|
del x, y, z
|
||||||
gc.collect()
|
gc.collect()
|
||||||
for r in a, b, c:
|
for r in a, b, c:
|
||||||
|
@ -1102,6 +1106,9 @@ class WeakMethodTestCase(unittest.TestCase):
|
||||||
_ne(a, f)
|
_ne(a, f)
|
||||||
_ne(b, e)
|
_ne(b, e)
|
||||||
_ne(b, f)
|
_ne(b, f)
|
||||||
|
# Compare with different types
|
||||||
|
_ne(a, x.some_method)
|
||||||
|
_eq(a, ALWAYS_EQ)
|
||||||
del x, y, z
|
del x, y, z
|
||||||
gc.collect()
|
gc.collect()
|
||||||
# Dead WeakMethods compare by identity
|
# Dead WeakMethods compare by identity
|
||||||
|
|
|
@ -15,6 +15,7 @@ import re
|
||||||
import io
|
import io
|
||||||
import contextlib
|
import contextlib
|
||||||
from test import support
|
from test import support
|
||||||
|
from test.support import ALWAYS_EQ, LARGEST, SMALLEST
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import gzip
|
import gzip
|
||||||
|
@ -530,14 +531,10 @@ class DateTimeTestCase(unittest.TestCase):
|
||||||
# some other types
|
# some other types
|
||||||
dbytes = dstr.encode('ascii')
|
dbytes = dstr.encode('ascii')
|
||||||
dtuple = now.timetuple()
|
dtuple = now.timetuple()
|
||||||
with self.assertRaises(TypeError):
|
self.assertFalse(dtime == 1970)
|
||||||
dtime == 1970
|
self.assertTrue(dtime != dbytes)
|
||||||
with self.assertRaises(TypeError):
|
self.assertFalse(dtime == bytearray(dbytes))
|
||||||
dtime != dbytes
|
self.assertTrue(dtime != dtuple)
|
||||||
with self.assertRaises(TypeError):
|
|
||||||
dtime == bytearray(dbytes)
|
|
||||||
with self.assertRaises(TypeError):
|
|
||||||
dtime != dtuple
|
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
dtime < float(1970)
|
dtime < float(1970)
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
|
@ -547,6 +544,18 @@ class DateTimeTestCase(unittest.TestCase):
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
dtime >= dtuple
|
dtime >= dtuple
|
||||||
|
|
||||||
|
self.assertTrue(dtime == ALWAYS_EQ)
|
||||||
|
self.assertFalse(dtime != ALWAYS_EQ)
|
||||||
|
self.assertTrue(dtime < LARGEST)
|
||||||
|
self.assertFalse(dtime > LARGEST)
|
||||||
|
self.assertTrue(dtime <= LARGEST)
|
||||||
|
self.assertFalse(dtime >= LARGEST)
|
||||||
|
self.assertFalse(dtime < SMALLEST)
|
||||||
|
self.assertTrue(dtime > SMALLEST)
|
||||||
|
self.assertFalse(dtime <= SMALLEST)
|
||||||
|
self.assertTrue(dtime >= SMALLEST)
|
||||||
|
|
||||||
|
|
||||||
class BinaryTestCase(unittest.TestCase):
|
class BinaryTestCase(unittest.TestCase):
|
||||||
|
|
||||||
# XXX What should str(Binary(b"\xff")) return? I'm chosing "\xff"
|
# XXX What should str(Binary(b"\xff")) return? I'm chosing "\xff"
|
||||||
|
|
|
@ -484,6 +484,8 @@ class Variable:
|
||||||
Note: if the Variable's master matters to behavior
|
Note: if the Variable's master matters to behavior
|
||||||
also compare self._master == other._master
|
also compare self._master == other._master
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(other, Variable):
|
||||||
|
return NotImplemented
|
||||||
return self.__class__.__name__ == other.__class__.__name__ \
|
return self.__class__.__name__ == other.__class__.__name__ \
|
||||||
and self._name == other._name
|
and self._name == other._name
|
||||||
|
|
||||||
|
|
|
@ -101,7 +101,9 @@ class Font:
|
||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return isinstance(other, Font) and self.name == other.name
|
if not isinstance(other, Font):
|
||||||
|
return NotImplemented
|
||||||
|
return self.name == other.name
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
return self.cget(key)
|
return self.cget(key)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import unittest
|
import unittest
|
||||||
import tkinter
|
import tkinter
|
||||||
from tkinter import font
|
from tkinter import font
|
||||||
from test.support import requires, run_unittest, gc_collect
|
from test.support import requires, run_unittest, gc_collect, ALWAYS_EQ
|
||||||
from tkinter.test.support import AbstractTkTest
|
from tkinter.test.support import AbstractTkTest
|
||||||
|
|
||||||
requires('gui')
|
requires('gui')
|
||||||
|
@ -70,6 +70,7 @@ class FontTest(AbstractTkTest, unittest.TestCase):
|
||||||
self.assertEqual(font1, font2)
|
self.assertEqual(font1, font2)
|
||||||
self.assertNotEqual(font1, font1.copy())
|
self.assertNotEqual(font1, font1.copy())
|
||||||
self.assertNotEqual(font1, 0)
|
self.assertNotEqual(font1, 0)
|
||||||
|
self.assertEqual(font1, ALWAYS_EQ)
|
||||||
|
|
||||||
def test_measure(self):
|
def test_measure(self):
|
||||||
self.assertIsInstance(self.font.measure('abc'), int)
|
self.assertIsInstance(self.font.measure('abc'), int)
|
||||||
|
|
|
@ -2,6 +2,7 @@ import unittest
|
||||||
import gc
|
import gc
|
||||||
from tkinter import (Variable, StringVar, IntVar, DoubleVar, BooleanVar, Tcl,
|
from tkinter import (Variable, StringVar, IntVar, DoubleVar, BooleanVar, Tcl,
|
||||||
TclError)
|
TclError)
|
||||||
|
from test.support import ALWAYS_EQ
|
||||||
|
|
||||||
|
|
||||||
class Var(Variable):
|
class Var(Variable):
|
||||||
|
@ -59,11 +60,17 @@ class TestVariable(TestBase):
|
||||||
# values doesn't matter, only class and name are checked
|
# values doesn't matter, only class and name are checked
|
||||||
v1 = Variable(self.root, name="abc")
|
v1 = Variable(self.root, name="abc")
|
||||||
v2 = Variable(self.root, name="abc")
|
v2 = Variable(self.root, name="abc")
|
||||||
|
self.assertIsNot(v1, v2)
|
||||||
self.assertEqual(v1, v2)
|
self.assertEqual(v1, v2)
|
||||||
|
|
||||||
v3 = Variable(self.root, name="abc")
|
v3 = StringVar(self.root, name="abc")
|
||||||
v4 = StringVar(self.root, name="abc")
|
self.assertNotEqual(v1, v3)
|
||||||
self.assertNotEqual(v3, v4)
|
|
||||||
|
V = type('Variable', (), {})
|
||||||
|
self.assertNotEqual(v1, V())
|
||||||
|
|
||||||
|
self.assertNotEqual(v1, object())
|
||||||
|
self.assertEqual(v1, ALWAYS_EQ)
|
||||||
|
|
||||||
def test_invalid_name(self):
|
def test_invalid_name(self):
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
|
|
|
@ -538,7 +538,9 @@ class TracebackException:
|
||||||
self.__cause__._load_lines()
|
self.__cause__._load_lines()
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return self.__dict__ == other.__dict__
|
if isinstance(other, TracebackException):
|
||||||
|
return self.__dict__ == other.__dict__
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self._str
|
return self._str
|
||||||
|
|
|
@ -43,6 +43,8 @@ class Statistic:
|
||||||
return hash((self.traceback, self.size, self.count))
|
return hash((self.traceback, self.size, self.count))
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
|
if not isinstance(other, Statistic):
|
||||||
|
return NotImplemented
|
||||||
return (self.traceback == other.traceback
|
return (self.traceback == other.traceback
|
||||||
and self.size == other.size
|
and self.size == other.size
|
||||||
and self.count == other.count)
|
and self.count == other.count)
|
||||||
|
@ -84,6 +86,8 @@ class StatisticDiff:
|
||||||
self.count, self.count_diff))
|
self.count, self.count_diff))
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
|
if not isinstance(other, StatisticDiff):
|
||||||
|
return NotImplemented
|
||||||
return (self.traceback == other.traceback
|
return (self.traceback == other.traceback
|
||||||
and self.size == other.size
|
and self.size == other.size
|
||||||
and self.size_diff == other.size_diff
|
and self.size_diff == other.size_diff
|
||||||
|
@ -153,9 +157,13 @@ class Frame:
|
||||||
return self._frame[1]
|
return self._frame[1]
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
|
if not isinstance(other, Frame):
|
||||||
|
return NotImplemented
|
||||||
return (self._frame == other._frame)
|
return (self._frame == other._frame)
|
||||||
|
|
||||||
def __lt__(self, other):
|
def __lt__(self, other):
|
||||||
|
if not isinstance(other, Frame):
|
||||||
|
return NotImplemented
|
||||||
return (self._frame < other._frame)
|
return (self._frame < other._frame)
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
|
@ -200,9 +208,13 @@ class Traceback(Sequence):
|
||||||
return hash(self._frames)
|
return hash(self._frames)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
|
if not isinstance(other, Traceback):
|
||||||
|
return NotImplemented
|
||||||
return (self._frames == other._frames)
|
return (self._frames == other._frames)
|
||||||
|
|
||||||
def __lt__(self, other):
|
def __lt__(self, other):
|
||||||
|
if not isinstance(other, Traceback):
|
||||||
|
return NotImplemented
|
||||||
return (self._frames < other._frames)
|
return (self._frames < other._frames)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
@ -271,6 +283,8 @@ class Trace:
|
||||||
return Traceback(self._trace[2])
|
return Traceback(self._trace[2])
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
|
if not isinstance(other, Trace):
|
||||||
|
return NotImplemented
|
||||||
return (self._trace == other._trace)
|
return (self._trace == other._trace)
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
|
@ -303,6 +317,8 @@ class _Traces(Sequence):
|
||||||
return trace._trace in self._traces
|
return trace._trace in self._traces
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
|
if not isinstance(other, _Traces):
|
||||||
|
return NotImplemented
|
||||||
return (self._traces == other._traces)
|
return (self._traces == other._traces)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
|
|
@ -2358,12 +2358,10 @@ class _Call(tuple):
|
||||||
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if other is ANY:
|
|
||||||
return True
|
|
||||||
try:
|
try:
|
||||||
len_other = len(other)
|
len_other = len(other)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
return False
|
return NotImplemented
|
||||||
|
|
||||||
self_name = ''
|
self_name = ''
|
||||||
if len(self) == 2:
|
if len(self) == 2:
|
||||||
|
|
|
@ -3,6 +3,7 @@ import re
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
|
from test.support import ALWAYS_EQ
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.test.testmock.support import is_instance
|
from unittest.test.testmock.support import is_instance
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
@ -322,6 +323,8 @@ class MockTest(unittest.TestCase):
|
||||||
self.assertFalse(mm != mock.ANY)
|
self.assertFalse(mm != mock.ANY)
|
||||||
self.assertTrue(mock.ANY == mm)
|
self.assertTrue(mock.ANY == mm)
|
||||||
self.assertFalse(mock.ANY != mm)
|
self.assertFalse(mock.ANY != mm)
|
||||||
|
self.assertTrue(mm == ALWAYS_EQ)
|
||||||
|
self.assertFalse(mm != ALWAYS_EQ)
|
||||||
|
|
||||||
call1 = mock.call(mock.MagicMock())
|
call1 = mock.call(mock.MagicMock())
|
||||||
call2 = mock.call(mock.ANY)
|
call2 = mock.call(mock.ANY)
|
||||||
|
@ -330,6 +333,11 @@ class MockTest(unittest.TestCase):
|
||||||
self.assertTrue(call2 == call1)
|
self.assertTrue(call2 == call1)
|
||||||
self.assertFalse(call2 != call1)
|
self.assertFalse(call2 != call1)
|
||||||
|
|
||||||
|
self.assertTrue(call1 == ALWAYS_EQ)
|
||||||
|
self.assertFalse(call1 != ALWAYS_EQ)
|
||||||
|
self.assertFalse(call1 == 1)
|
||||||
|
self.assertTrue(call1 != 1)
|
||||||
|
|
||||||
|
|
||||||
def test_assert_called_with(self):
|
def test_assert_called_with(self):
|
||||||
mock = Mock()
|
mock = Mock()
|
||||||
|
|
|
@ -75,14 +75,14 @@ class WeakMethod(ref):
|
||||||
if not self._alive or not other._alive:
|
if not self._alive or not other._alive:
|
||||||
return self is other
|
return self is other
|
||||||
return ref.__eq__(self, other) and self._func_ref == other._func_ref
|
return ref.__eq__(self, other) and self._func_ref == other._func_ref
|
||||||
return False
|
return NotImplemented
|
||||||
|
|
||||||
def __ne__(self, other):
|
def __ne__(self, other):
|
||||||
if isinstance(other, WeakMethod):
|
if isinstance(other, WeakMethod):
|
||||||
if not self._alive or not other._alive:
|
if not self._alive or not other._alive:
|
||||||
return self is not other
|
return self is not other
|
||||||
return ref.__ne__(self, other) or self._func_ref != other._func_ref
|
return ref.__ne__(self, other) or self._func_ref != other._func_ref
|
||||||
return True
|
return NotImplemented
|
||||||
|
|
||||||
__hash__ = ref.__hash__
|
__hash__ = ref.__hash__
|
||||||
|
|
||||||
|
|
|
@ -313,31 +313,38 @@ class DateTime:
|
||||||
s = self.timetuple()
|
s = self.timetuple()
|
||||||
o = other.timetuple()
|
o = other.timetuple()
|
||||||
else:
|
else:
|
||||||
otype = (hasattr(other, "__class__")
|
s = self
|
||||||
and other.__class__.__name__
|
o = NotImplemented
|
||||||
or type(other))
|
|
||||||
raise TypeError("Can't compare %s and %s" %
|
|
||||||
(self.__class__.__name__, otype))
|
|
||||||
return s, o
|
return s, o
|
||||||
|
|
||||||
def __lt__(self, other):
|
def __lt__(self, other):
|
||||||
s, o = self.make_comparable(other)
|
s, o = self.make_comparable(other)
|
||||||
|
if o is NotImplemented:
|
||||||
|
return NotImplemented
|
||||||
return s < o
|
return s < o
|
||||||
|
|
||||||
def __le__(self, other):
|
def __le__(self, other):
|
||||||
s, o = self.make_comparable(other)
|
s, o = self.make_comparable(other)
|
||||||
|
if o is NotImplemented:
|
||||||
|
return NotImplemented
|
||||||
return s <= o
|
return s <= o
|
||||||
|
|
||||||
def __gt__(self, other):
|
def __gt__(self, other):
|
||||||
s, o = self.make_comparable(other)
|
s, o = self.make_comparable(other)
|
||||||
|
if o is NotImplemented:
|
||||||
|
return NotImplemented
|
||||||
return s > o
|
return s > o
|
||||||
|
|
||||||
def __ge__(self, other):
|
def __ge__(self, other):
|
||||||
s, o = self.make_comparable(other)
|
s, o = self.make_comparable(other)
|
||||||
|
if o is NotImplemented:
|
||||||
|
return NotImplemented
|
||||||
return s >= o
|
return s >= o
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
s, o = self.make_comparable(other)
|
s, o = self.make_comparable(other)
|
||||||
|
if o is NotImplemented:
|
||||||
|
return NotImplemented
|
||||||
return s == o
|
return s == o
|
||||||
|
|
||||||
def timetuple(self):
|
def timetuple(self):
|
||||||
|
|
|
@ -0,0 +1,4 @@
|
||||||
|
Fixed ``__eq__``, ``__lt__`` etc implementations in some classes. They now
|
||||||
|
return :data:`NotImplemented` for unsupported type of the other operand.
|
||||||
|
This allows the other operand to play role (for example the equality
|
||||||
|
comparison with :data:`~unittest.mock.ANY` will return ``True``).
|
File diff suppressed because it is too large
Load Diff
|
@ -281,10 +281,14 @@ class PopupViewer:
|
||||||
self.__window.deiconify()
|
self.__window.deiconify()
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return self.__menutext == other.__menutext
|
if isinstance(self, PopupViewer):
|
||||||
|
return self.__menutext == other.__menutext
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
def __lt__(self, other):
|
def __lt__(self, other):
|
||||||
return self.__menutext < other.__menutext
|
if isinstance(self, PopupViewer):
|
||||||
|
return self.__menutext < other.__menutext
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
|
||||||
def make_view_popups(switchboard, root, extrapath):
|
def make_view_popups(switchboard, root, extrapath):
|
||||||
|
|
Loading…
Reference in New Issue