diff --git a/Lib/_abcoll.py b/Lib/_abcoll.py index 82ded379a4e..3743ea82f04 100644 --- a/Lib/_abcoll.py +++ b/Lib/_abcoll.py @@ -369,8 +369,9 @@ class Mapping(Sized, Iterable, Container): __hash__ = None def __eq__(self, other): - return isinstance(other, Mapping) and \ - dict(self.items()) == dict(other.items()) + if not isinstance(other, Mapping): + return NotImplemented + return dict(self.items()) == dict(other.items()) def __ne__(self, other): return not (self == other) diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py index d1fd60eec53..f48ae5a4c28 100644 --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -1,4 +1,4 @@ -import unittest, doctest +import unittest, doctest, operator from test import test_support from collections import namedtuple import pickle, cPickle, copy @@ -232,6 +232,37 @@ class ABCTestCase(unittest.TestCase): self.assertFalse(isinstance(C(), abc)) self.assertFalse(issubclass(C, abc)) + def validate_comparison(self, instance): + ops = ['lt', 'gt', 'le', 'ge', 'ne', 'or', 'and', 'xor', 'sub'] + operators = {} + for op in ops: + name = '__' + op + '__' + operators[name] = getattr(operator, name) + + class Other: + def __init__(self): + self.right_side = False + def __eq__(self, other): + self.right_side = True + return True + __lt__ = __eq__ + __gt__ = __eq__ + __le__ = __eq__ + __ge__ = __eq__ + __ne__ = __eq__ + __ror__ = __eq__ + __rand__ = __eq__ + __rxor__ = __eq__ + __rsub__ = __eq__ + + for name, op in operators.items(): + if not hasattr(instance, name): + continue + other = Other() + op(instance, other) + self.assertTrue(other.right_side,'Right side not called for %s.%s' + % (type(instance), name)) + class TestOneTrickPonyABCs(ABCTestCase): def test_Hashable(self): @@ -409,6 +440,14 @@ class TestCollectionABCs(ABCTestCase): self.failUnless(isinstance(sample(), Set)) self.failUnless(issubclass(sample, Set)) self.validate_abstract_methods(Set, '__contains__', '__iter__', '__len__') + class MySet(Set): + def __contains__(self, x): + return False + def __len__(self): + return 0 + def __iter__(self): + return iter([]) + self.validate_comparison(MySet()) def test_hash_Set(self): class OneTwoThreeSet(Set): @@ -472,6 +511,14 @@ class TestCollectionABCs(ABCTestCase): self.failUnless(issubclass(sample, Mapping)) self.validate_abstract_methods(Mapping, '__contains__', '__iter__', '__len__', '__getitem__') + class MyMapping(collections.Mapping): + def __len__(self): + return 0 + def __getitem__(self, i): + raise IndexError + def __iter__(self): + return iter(()) + self.validate_comparison(MyMapping()) def test_MutableMapping(self): for sample in [dict]: diff --git a/Misc/NEWS b/Misc/NEWS index 8e620cbdb4a..2fc7890de44 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -55,6 +55,9 @@ C-API Library ------- +- Issue #8729: Return NotImplemented from collections.Mapping.__eq__ when + comparing to a non-mapping. + - Issue #5918: Fix a crash in the parser module. - Issue #8688: Distutils now recalculates MANIFEST everytime.