Anti-registration of various ABC methods.

- Issue #25958: Support "anti-registration" of special methods from
  various ABCs, like __hash__, __iter__ or __len__.  All these (and
  several more) can be set to None in an implementation class and the
  behavior will be as if the method is not defined at all.
  (Previously, this mechanism existed only for __hash__, to make
  mutable classes unhashable.)  Code contributed by Andrew Barnert and
  Ivan Levkivskyi.
This commit is contained in:
Guido van Rossum 2016-08-18 09:22:23 -07:00
parent 0a6996d87d
commit 97c1adf393
15 changed files with 300 additions and 62 deletions

View File

@ -114,7 +114,8 @@ ABC Inherits from Abstract Methods Mixin
.. class:: Reversible
ABC for classes that provide the :meth:`__reversed__` method.
ABC for iterable classes that also provide the :meth:`__reversed__`
method.
.. versionadded:: 3.6

View File

@ -1063,6 +1063,12 @@ to ``type(x).__getitem__(x, i)``. Except where mentioned, attempts to execute a
operation raise an exception when no appropriate method is defined (typically
:exc:`AttributeError` or :exc:`TypeError`).
Setting a special method to ``None`` indicates that the corresponding
operation is not available. For example, if a class sets
:meth:`__iter__` to ``None``, the class is not iterable, so calling
:func:`iter` on its instances will raise a :exc:`TypeError` (without
falling back to :meth:`__getitem__`). [#]_
When implementing a class that emulates any built-in type, it is important that
the emulation only be implemented to the degree that it makes sense for the
object being modelled. For example, some sequences may work well with retrieval
@ -2113,7 +2119,7 @@ left undefined.
(``+``, ``-``, ``*``, ``@``, ``/``, ``//``, ``%``, :func:`divmod`,
:func:`pow`, ``**``, ``<<``, ``>>``, ``&``, ``^``, ``|``) with reflected
(swapped) operands. These functions are only called if the left operand does
not support the corresponding operation and the operands are of different
not support the corresponding operation [#]_ and the operands are of different
types. [#]_ For instance, to evaluate the expression ``x - y``, where *y* is
an instance of a class that has an :meth:`__rsub__` method, ``y.__rsub__(x)``
is called if ``x.__sub__(y)`` returns *NotImplemented*.
@ -2529,6 +2535,17 @@ An example of an asynchronous context manager class::
controlled conditions. It generally isn't a good idea though, since it can
lead to some very strange behaviour if it is handled incorrectly.
.. [#] The :meth:`__hash__`, :meth:`__iter__`, :meth:`__reversed__`, and
:meth:`__contains__` methods have special handling for this; others
will still raise a :exc:`TypeError`, but may do so by relying on
the behavior that ``None`` is not callable.
.. [#] "Does not support" here means that the class has no such method, or
the method returns ``NotImplemented``. Do not set the method to
``None`` if you want to force fallback to the right operand's reflected
method--that will instead have the opposite effect of explicitly
*blocking* such fallback.
.. [#] For operands of the same type, it is assumed that if the non-reflected method
(such as :meth:`__add__`) fails the operation is not supported, which is why the
reflected method is not called.

View File

@ -62,6 +62,18 @@ del _coro
### ONE-TRICK PONIES ###
def _check_methods(C, *methods):
mro = C.__mro__
for method in methods:
for B in mro:
if method in B.__dict__:
if B.__dict__[method] is None:
return NotImplemented
break
else:
return NotImplemented
return True
class Hashable(metaclass=ABCMeta):
__slots__ = ()
@ -73,11 +85,7 @@ class Hashable(metaclass=ABCMeta):
@classmethod
def __subclasshook__(cls, C):
if cls is Hashable:
for B in C.__mro__:
if "__hash__" in B.__dict__:
if B.__dict__["__hash__"]:
return True
break
return _check_methods(C, "__hash__")
return NotImplemented
@ -92,11 +100,7 @@ class Awaitable(metaclass=ABCMeta):
@classmethod
def __subclasshook__(cls, C):
if cls is Awaitable:
for B in C.__mro__:
if "__await__" in B.__dict__:
if B.__dict__["__await__"]:
return True
break
return _check_methods(C, "__await__")
return NotImplemented
@ -137,14 +141,7 @@ class Coroutine(Awaitable):
@classmethod
def __subclasshook__(cls, C):
if cls is Coroutine:
mro = C.__mro__
for method in ('__await__', 'send', 'throw', 'close'):
for base in mro:
if method in base.__dict__:
break
else:
return NotImplemented
return True
return _check_methods(C, '__await__', 'send', 'throw', 'close')
return NotImplemented
@ -162,8 +159,7 @@ class AsyncIterable(metaclass=ABCMeta):
@classmethod
def __subclasshook__(cls, C):
if cls is AsyncIterable:
if any("__aiter__" in B.__dict__ for B in C.__mro__):
return True
return _check_methods(C, "__aiter__")
return NotImplemented
@ -182,9 +178,7 @@ class AsyncIterator(AsyncIterable):
@classmethod
def __subclasshook__(cls, C):
if cls is AsyncIterator:
if (any("__anext__" in B.__dict__ for B in C.__mro__) and
any("__aiter__" in B.__dict__ for B in C.__mro__)):
return True
return _check_methods(C, "__anext__", "__aiter__")
return NotImplemented
@ -200,8 +194,7 @@ class Iterable(metaclass=ABCMeta):
@classmethod
def __subclasshook__(cls, C):
if cls is Iterable:
if any("__iter__" in B.__dict__ for B in C.__mro__):
return True
return _check_methods(C, "__iter__")
return NotImplemented
@ -220,9 +213,7 @@ class Iterator(Iterable):
@classmethod
def __subclasshook__(cls, C):
if cls is Iterator:
if (any("__next__" in B.__dict__ for B in C.__mro__) and
any("__iter__" in B.__dict__ for B in C.__mro__)):
return True
return _check_methods(C, '__iter__', '__next__')
return NotImplemented
Iterator.register(bytes_iterator)
@ -246,16 +237,13 @@ class Reversible(Iterable):
@abstractmethod
def __reversed__(self):
return NotImplemented
while False:
yield None
@classmethod
def __subclasshook__(cls, C):
if cls is Reversible:
for B in C.__mro__:
if "__reversed__" in B.__dict__:
if B.__dict__["__reversed__"] is not None:
return True
break
return _check_methods(C, "__reversed__", "__iter__")
return NotImplemented
@ -302,16 +290,9 @@ class Generator(Iterator):
@classmethod
def __subclasshook__(cls, C):
if cls is Generator:
mro = C.__mro__
for method in ('__iter__', '__next__', 'send', 'throw', 'close'):
for base in mro:
if method in base.__dict__:
break
else:
return _check_methods(C, '__iter__', '__next__',
'send', 'throw', 'close')
return NotImplemented
return True
return NotImplemented
Generator.register(generator)
@ -327,8 +308,7 @@ class Sized(metaclass=ABCMeta):
@classmethod
def __subclasshook__(cls, C):
if cls is Sized:
if any("__len__" in B.__dict__ for B in C.__mro__):
return True
return _check_methods(C, "__len__")
return NotImplemented
@ -343,8 +323,7 @@ class Container(metaclass=ABCMeta):
@classmethod
def __subclasshook__(cls, C):
if cls is Container:
if any("__contains__" in B.__dict__ for B in C.__mro__):
return True
return _check_methods(C, "__contains__")
return NotImplemented
@ -359,8 +338,7 @@ class Callable(metaclass=ABCMeta):
@classmethod
def __subclasshook__(cls, C):
if cls is Callable:
if any("__call__" in B.__dict__ for B in C.__mro__):
return True
return _check_methods(C, "__call__")
return NotImplemented
@ -640,6 +618,8 @@ class Mapping(Sized, Iterable, Container):
return NotImplemented
return dict(self.items()) == dict(other.items())
__reversed__ = None
Mapping.register(mappingproxy)

View File

@ -83,6 +83,10 @@ class AugAssignTest(unittest.TestCase):
def __iadd__(self, val):
return aug_test3(self.val + val)
class aug_test4(aug_test3):
"""Blocks inheritance, and fallback to __add__"""
__iadd__ = None
x = aug_test(1)
y = x
x += 10
@ -106,6 +110,10 @@ class AugAssignTest(unittest.TestCase):
self.assertTrue(y is not x)
self.assertEqual(x.val, 13)
x = aug_test4(4)
with self.assertRaises(TypeError):
x += 10
def testCustomMethods2(test_self):
output = []

View File

@ -2,7 +2,7 @@
import unittest
from test import support
from operator import eq, le
from operator import eq, le, ne
from abc import ABCMeta
def gcd(a, b):
@ -388,6 +388,54 @@ class OperationOrderTests(unittest.TestCase):
self.assertEqual(op_sequence(eq, B, V), ['B.__eq__', 'V.__eq__'])
self.assertEqual(op_sequence(le, B, V), ['B.__le__', 'V.__ge__'])
class SupEq(object):
"""Class that can test equality"""
def __eq__(self, other):
return True
class S(SupEq):
"""Subclass of SupEq that should fail"""
__eq__ = None
class F(object):
"""Independent class that should fall back"""
class X(object):
"""Independent class that should fail"""
__eq__ = None
class SN(SupEq):
"""Subclass of SupEq that can test equality, but not non-equality"""
__ne__ = None
class XN:
"""Independent class that can test equality, but not non-equality"""
def __eq__(self, other):
return True
__ne__ = None
class FallbackBlockingTests(unittest.TestCase):
"""Unit tests for None method blocking"""
def test_fallback_rmethod_blocking(self):
e, f, s, x = SupEq(), F(), S(), X()
self.assertEqual(e, e)
self.assertEqual(e, f)
self.assertEqual(f, e)
# left operand is checked first
self.assertEqual(e, x)
self.assertRaises(TypeError, eq, x, e)
# S is a subclass, so it's always checked first
self.assertRaises(TypeError, eq, e, s)
self.assertRaises(TypeError, eq, s, e)
def test_fallback_ne_blocking(self):
e, sn, xn = SupEq(), SN(), XN()
self.assertFalse(e != e)
self.assertRaises(TypeError, ne, e, sn)
self.assertRaises(TypeError, ne, sn, e)
self.assertFalse(e != xn)
self.assertRaises(TypeError, ne, xn, e)
if __name__ == "__main__":
unittest.main()

View File

@ -333,6 +333,17 @@ class BoolTest(unittest.TestCase):
except (Exception) as e_len:
self.assertEqual(str(e_bool), str(e_len))
def test_blocked(self):
class A:
__bool__ = None
self.assertRaises(TypeError, bool, A())
class B:
def __len__(self):
return 10
__bool__ = None
self.assertRaises(TypeError, bool, B())
def test_real_and_imag(self):
self.assertEqual(True.real, 1)
self.assertEqual(True.imag, 0)

View File

@ -843,6 +843,36 @@ class BytesTest(BaseBytesTest, unittest.TestCase):
self.assertRaises(OverflowError,
PyBytes_FromFormat, b'%c', c_int(256))
def test_bytes_blocking(self):
class IterationBlocked(list):
__bytes__ = None
i = [0, 1, 2, 3]
self.assertEqual(bytes(i), b'\x00\x01\x02\x03')
self.assertRaises(TypeError, bytes, IterationBlocked(i))
# At least in CPython, because bytes.__new__ and the C API
# PyBytes_FromObject have different fallback rules, integer
# fallback is handled specially, so test separately.
class IntBlocked(int):
__bytes__ = None
self.assertEqual(bytes(3), b'\0\0\0')
self.assertRaises(TypeError, bytes, IntBlocked(3))
# While there is no separately-defined rule for handling bytes
# subclasses differently from other buffer-interface classes,
# an implementation may well special-case them (as CPython 2.x
# str did), so test them separately.
class BytesSubclassBlocked(bytes):
__bytes__ = None
self.assertEqual(bytes(b'ab'), b'ab')
self.assertRaises(TypeError, bytes, BytesSubclassBlocked(b'ab'))
class BufferBlocked(bytearray):
__bytes__ = None
ba, bb = bytearray(b'ab'), BufferBlocked(b'ab')
self.assertEqual(bytes(ba), b'ab')
self.assertRaises(TypeError, bytes, bb)
class ByteArrayTest(BaseBytesTest, unittest.TestCase):
type2test = bytearray

View File

@ -499,6 +499,9 @@ class ABCTestCase(unittest.TestCase):
self.assertTrue(other.right_side,'Right side not called for %s.%s'
% (type(instance), name))
def _test_gen():
yield
class TestOneTrickPonyABCs(ABCTestCase):
def test_Awaitable(self):
@ -686,7 +689,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
samples = [bytes(), str(),
tuple(), list(), set(), frozenset(), dict(),
dict().keys(), dict().items(), dict().values(),
(lambda: (yield))(),
_test_gen(),
(x for x in []),
]
for x in samples:
@ -700,6 +703,15 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertFalse(issubclass(str, I))
self.validate_abstract_methods(Iterable, '__iter__')
self.validate_isinstance(Iterable, '__iter__')
# Check None blocking
class It:
def __iter__(self): return iter([])
class ItBlocked(It):
__iter__ = None
self.assertTrue(issubclass(It, Iterable))
self.assertTrue(isinstance(It(), Iterable))
self.assertFalse(issubclass(ItBlocked, Iterable))
self.assertFalse(isinstance(ItBlocked(), Iterable))
def test_Reversible(self):
# Check some non-reversibles
@ -707,8 +719,18 @@ class TestOneTrickPonyABCs(ABCTestCase):
for x in non_samples:
self.assertNotIsInstance(x, Reversible)
self.assertFalse(issubclass(type(x), Reversible), repr(type(x)))
# Check some reversibles
samples = [tuple(), list()]
# Check some non-reversible iterables
non_reversibles = [dict().keys(), dict().items(), dict().values(),
Counter(), Counter().keys(), Counter().items(),
Counter().values(), _test_gen(),
(x for x in []), iter([]), reversed([])]
for x in non_reversibles:
self.assertNotIsInstance(x, Reversible)
self.assertFalse(issubclass(type(x), Reversible), repr(type(x)))
# Check some reversible iterables
samples = [bytes(), str(), tuple(), list(), OrderedDict(),
OrderedDict().keys(), OrderedDict().items(),
OrderedDict().values()]
for x in samples:
self.assertIsInstance(x, Reversible)
self.assertTrue(issubclass(type(x), Reversible), repr(type(x)))
@ -725,6 +747,29 @@ class TestOneTrickPonyABCs(ABCTestCase):
self.assertEqual(list(reversed(R())), [])
self.assertFalse(issubclass(float, R))
self.validate_abstract_methods(Reversible, '__reversed__', '__iter__')
# Check reversible non-iterable (which is not Reversible)
class RevNoIter:
def __reversed__(self): return reversed([])
class RevPlusIter(RevNoIter):
def __iter__(self): return iter([])
self.assertFalse(issubclass(RevNoIter, Reversible))
self.assertFalse(isinstance(RevNoIter(), Reversible))
self.assertTrue(issubclass(RevPlusIter, Reversible))
self.assertTrue(isinstance(RevPlusIter(), Reversible))
# Check None blocking
class Rev:
def __iter__(self): return iter([])
def __reversed__(self): return reversed([])
class RevItBlocked(Rev):
__iter__ = None
class RevRevBlocked(Rev):
__reversed__ = None
self.assertTrue(issubclass(Rev, Reversible))
self.assertTrue(isinstance(Rev(), Reversible))
self.assertFalse(issubclass(RevItBlocked, Reversible))
self.assertFalse(isinstance(RevItBlocked(), Reversible))
self.assertFalse(issubclass(RevRevBlocked, Reversible))
self.assertFalse(isinstance(RevRevBlocked(), Reversible))
def test_Iterator(self):
non_samples = [None, 42, 3.14, 1j, b"", "", (), [], {}, set()]
@ -736,7 +781,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
iter(set()), iter(frozenset()),
iter(dict().keys()), iter(dict().items()),
iter(dict().values()),
(lambda: (yield))(),
_test_gen(),
(x for x in []),
]
for x in samples:
@ -824,7 +869,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
def test_Sized(self):
non_samples = [None, 42, 3.14, 1j,
(lambda: (yield))(),
_test_gen(),
(x for x in []),
]
for x in non_samples:
@ -842,7 +887,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
def test_Container(self):
non_samples = [None, 42, 3.14, 1j,
(lambda: (yield))(),
_test_gen(),
(x for x in []),
]
for x in non_samples:
@ -861,7 +906,7 @@ class TestOneTrickPonyABCs(ABCTestCase):
def test_Callable(self):
non_samples = [None, 42, 3.14, 1j,
"", b"", (), [], {}, set(),
(lambda: (yield))(),
_test_gen(),
(x for x in []),
]
for x in non_samples:
@ -1276,6 +1321,7 @@ class TestCollectionABCs(ABCTestCase):
def __iter__(self):
return iter(())
self.validate_comparison(MyMapping())
self.assertRaises(TypeError, reversed, MyMapping())
def test_MutableMapping(self):
for sample in [dict]:

View File

@ -84,6 +84,31 @@ class TestContains(unittest.TestCase):
self.assertTrue(container == constructor(values))
self.assertTrue(container == container)
def test_block_fallback(self):
# blocking fallback with __contains__ = None
class ByContains(object):
def __contains__(self, other):
return False
c = ByContains()
class BlockContains(ByContains):
"""Is not a container
This class is a perfectly good iterable (as tested by
list(bc)), as well as inheriting from a perfectly good
container, but __contains__ = None prevents the usual
fallback to iteration in the container protocol. That
is, normally, 0 in bc would fall back to the equivalent
of any(x==0 for x in bc), but here it's blocked from
doing so.
"""
def __iter__(self):
while False:
yield None
__contains__ = None
bc = BlockContains()
self.assertFalse(0 in c)
self.assertFalse(0 in list(bc))
self.assertRaises(TypeError, lambda: 0 in bc)
if __name__ == '__main__':
unittest.main()

View File

@ -223,7 +223,7 @@ class TestReversed(unittest.TestCase, PickleTest):
def test_objmethods(self):
# Objects must have __len__() and __getitem__() implemented.
class NoLen(object):
def __getitem__(self): return 1
def __getitem__(self, i): return 1
nl = NoLen()
self.assertRaises(TypeError, reversed, nl)
@ -232,6 +232,13 @@ class TestReversed(unittest.TestCase, PickleTest):
ngi = NoGetItem()
self.assertRaises(TypeError, reversed, ngi)
class Blocked(object):
def __getitem__(self, i): return 1
def __len__(self): return 2
__reversed__ = None
b = Blocked()
self.assertRaises(TypeError, reversed, b)
def test_pickle(self):
for data in 'abc', range(5), tuple(enumerate('abc')), range(1,17,5):
self.check_pickle(reversed(data), list(data)[::-1])

View File

@ -54,6 +54,14 @@ class UnlimitedSequenceClass:
def __getitem__(self, i):
return i
class DefaultIterClass:
pass
class NoIterClass:
def __getitem__(self, i):
return i
__iter__ = None
# Main test suite
class TestCase(unittest.TestCase):
@ -995,6 +1003,10 @@ class TestCase(unittest.TestCase):
def test_free_after_iterating(self):
check_free_after_iterating(self, iter, SequenceClass, (0,))
def test_error_iter(self):
for typ in (DefaultIterClass, NoIterClass):
self.assertRaises(TypeError, iter, typ())
def test_main():
run_unittest(TestCase)

View File

@ -986,6 +986,19 @@ class UnicodeTest(string_tests.CommonTest,
def __format__(self, format_spec):
return int.__format__(self * 2, format_spec)
class M:
def __init__(self, x):
self.x = x
def __repr__(self):
return 'M(' + self.x + ')'
__str__ = None
class N:
def __init__(self, x):
self.x = x
def __repr__(self):
return 'N(' + self.x + ')'
__format__ = None
self.assertEqual(''.format(), '')
self.assertEqual('abc'.format(), 'abc')
@ -1200,6 +1213,16 @@ class UnicodeTest(string_tests.CommonTest,
self.assertEqual("0x{:0{:d}X}".format(0x0,16), "0x0000000000000000")
# Blocking fallback
m = M('data')
self.assertEqual("{!r}".format(m), 'M(data)')
self.assertRaises(TypeError, "{!s}".format, m)
self.assertRaises(TypeError, "{}".format, m)
n = N('data')
self.assertEqual("{!r}".format(n), 'N(data)')
self.assertEqual("{!s}".format(n), 'N(data)')
self.assertRaises(TypeError, "{}".format, n)
def test_format_map(self):
self.assertEqual(''.format_map({}), '')
self.assertEqual('a'.format_map({}), 'a')

View File

@ -73,6 +73,14 @@ Core and Builtins
Library
-------
- Issue #25958: Support "anti-registration" of special methods from
various ABCs, like __hash__, __iter__ or __len__. All these (and
several more) can be set to None in an implementation class and the
behavior will be as if the method is not defined at all.
(Previously, this mechanism existed only for __hash__, to make
mutable classes unhashable.) Code contributed by Andrew Barnert and
Ivan Levkivskyi.
- Issue #16764: Support keyword arguments to zlib.decompress(). Patch by
Xiang Zhang.

View File

@ -250,6 +250,13 @@ reversed_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
return NULL;
reversed_meth = _PyObject_LookupSpecial(seq, &PyId___reversed__);
if (reversed_meth == Py_None) {
Py_DECREF(reversed_meth);
PyErr_Format(PyExc_TypeError,
"'%.200s' object is not reversible",
Py_TYPE(seq)->tp_name);
return NULL;
}
if (reversed_meth != NULL) {
PyObject *res = PyObject_CallFunctionObjArgs(reversed_meth, NULL);
Py_DECREF(reversed_meth);
@ -259,8 +266,9 @@ reversed_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
return NULL;
if (!PySequence_Check(seq)) {
PyErr_SetString(PyExc_TypeError,
"argument to reversed() must be a sequence");
PyErr_Format(PyExc_TypeError,
"'%.200s' object is not reversible",
Py_TYPE(seq)->tp_name);
return NULL;
}

View File

@ -5856,6 +5856,13 @@ slot_sq_contains(PyObject *self, PyObject *value)
_Py_IDENTIFIER(__contains__);
func = lookup_maybe(self, &PyId___contains__);
if (func == Py_None) {
Py_DECREF(func);
PyErr_Format(PyExc_TypeError,
"'%.200s' object is not a container",
Py_TYPE(self)->tp_name);
return -1;
}
if (func != NULL) {
args = PyTuple_Pack(1, value);
if (args == NULL)
@ -6241,6 +6248,13 @@ slot_tp_iter(PyObject *self)
_Py_IDENTIFIER(__iter__);
func = lookup_method(self, &PyId___iter__);
if (func == Py_None) {
Py_DECREF(func);
PyErr_Format(PyExc_TypeError,
"'%.200s' object is not iterable",
Py_TYPE(self)->tp_name);
return NULL;
}
if (func != NULL) {
PyObject *args;
args = res = PyTuple_New(0);