Issue #18244: Adopt C3-based linearization in functools.singledispatch for improved ABC support
This commit is contained in:
parent
04926aeb2f
commit
3720c77e30
178
Lib/functools.py
178
Lib/functools.py
|
@ -365,46 +365,138 @@ def lru_cache(maxsize=128, typed=False):
|
|||
### singledispatch() - single-dispatch generic function decorator
|
||||
################################################################################
|
||||
|
||||
def _compose_mro(cls, haystack):
|
||||
"""Calculates the MRO for a given class `cls`, including relevant abstract
|
||||
base classes from `haystack`.
|
||||
def _c3_merge(sequences):
|
||||
"""Merges MROs in *sequences* to a single MRO using the C3 algorithm.
|
||||
|
||||
Adapted from http://www.python.org/download/releases/2.3/mro/.
|
||||
|
||||
"""
|
||||
result = []
|
||||
while True:
|
||||
sequences = [s for s in sequences if s] # purge empty sequences
|
||||
if not sequences:
|
||||
return result
|
||||
for s1 in sequences: # find merge candidates among seq heads
|
||||
candidate = s1[0]
|
||||
for s2 in sequences:
|
||||
if candidate in s2[1:]:
|
||||
candidate = None
|
||||
break # reject the current head, it appears later
|
||||
else:
|
||||
break
|
||||
if not candidate:
|
||||
raise RuntimeError("Inconsistent hierarchy")
|
||||
result.append(candidate)
|
||||
# remove the chosen candidate
|
||||
for seq in sequences:
|
||||
if seq[0] == candidate:
|
||||
del seq[0]
|
||||
|
||||
def _c3_mro(cls, abcs=None):
    """Computes the method resolution order using extended C3 linearization.

    Without *abcs* the result is the same linearization the interpreter
    builds for ordinary method resolution.

    With *abcs* (a list of abstract base classes), every ABC that is related
    to *cls* is woven into the result; unrelated ABCs are dropped.  An ABC is
    inserted at the class that introduces its behaviour, i.e. the class for
    which issubclass(cls, abc) is True while it is False for each of its
    direct bases.  Such implicit ABCs (registered, or inferred from a special
    method like __len__) are placed directly after the last ABC that is
    explicitly listed among the bases of that class.  When two implicit ABCs
    end up adjacent, their relative order follows the order of *abcs*.

    """
    direct_bases = cls.__bases__
    # Locate the slice end just past the last direct base that is an
    # explicit ABC; bases up to that point are considered first.
    boundary = 0
    for end in range(len(direct_bases), 0, -1):
        if hasattr(direct_bases[end - 1], '__abstractmethods__'):
            boundary = end
            break
    pending = list(abcs) if abcs else []
    explicit_bases = list(direct_bases[:boundary])
    other_bases = list(direct_bases[boundary:])
    # ABCs whose behaviour is introduced by *cls* itself: *cls* passes the
    # subclass check but none of its direct bases does.
    abstract_bases = [
        abc for abc in pending
        if issubclass(cls, abc)
        and not any(issubclass(parent, abc) for parent in direct_bases)
    ]
    # ABCs consumed at this level must not be inserted again further up.
    for abc in abstract_bases:
        pending.remove(abc)
    linearizations = [
        _c3_mro(parent, abcs=pending)
        for parent in explicit_bases + abstract_bases + other_bases
    ]
    return _c3_merge(
        [[cls]] + linearizations +
        [explicit_bases, abstract_bases, other_bases]
    )
|
||||
|
||||
def _compose_mro(cls, types):
    """Calculates the method resolution order for a given class *cls*.

    Includes relevant abstract base classes (with their respective bases) from
    the *types* iterable. Uses a modified C3 linearization algorithm.

    """
    bases = set(cls.__mro__)

    # Remove entries which are already present in the __mro__ or unrelated.
    def is_related(typ):
        return (typ not in bases and hasattr(typ, '__mro__')
                and issubclass(cls, typ))
    types = [n for n in types if is_related(n)]

    # Remove entries which are strict bases of other entries (they will end up
    # in the MRO anyway).
    def is_strict_base(typ):
        return any(typ != other and typ in other.__mro__ for other in types)
    types = [n for n in types if not is_strict_base(n)]

    # Subclasses of the ABCs in *types* which are also implemented by
    # *cls* can be used to stabilize ABC ordering.
    type_set = set(types)
    mro = []
    for typ in types:
        found = []
        for sub in typ.__subclasses__():
            if sub not in bases and issubclass(cls, sub):
                # Keep only the part of the subclass' MRO that was asked for.
                found.append([s for s in sub.__mro__ if s in type_set])
        if not found:
            mro.append(typ)
            continue
        # Favor subclasses with the biggest number of useful bases
        found.sort(key=len, reverse=True)
        for sub in found:
            for subcls in sub:
                if subcls not in mro:
                    mro.append(subcls)
    return _c3_mro(cls, abcs=mro)
|
||||
|
||||
def _find_impl(cls, registry):
|
||||
"""Returns the best matching implementation for the given class `cls` in
|
||||
`registry`. Where there is no registered implementation for a specific
|
||||
type, its method resolution order is used to find a more generic
|
||||
implementation.
|
||||
"""Returns the best matching implementation from *registry* for type *cls*.
|
||||
|
||||
Note: if `registry` does not contain an implementation for the base
|
||||
`object` type, this function may return None.
|
||||
Where there is no registered implementation for a specific type, its method
|
||||
resolution order is used to find a more generic implementation.
|
||||
|
||||
Note: if *registry* does not contain an implementation for the base
|
||||
*object* type, this function may return None.
|
||||
|
||||
"""
|
||||
mro = _compose_mro(cls, registry.keys())
|
||||
match = None
|
||||
for t in mro:
|
||||
if match is not None:
|
||||
# If `match` is an ABC but there is another unrelated, equally
|
||||
# matching ABC. Refuse the temptation to guess.
|
||||
if (t in registry and not issubclass(match, t)
|
||||
and match not in cls.__mro__):
|
||||
# If *match* is an implicit ABC but there is another unrelated,
|
||||
# equally matching implicit ABC, refuse the temptation to guess.
|
||||
if (t in registry and t not in cls.__mro__
|
||||
and match not in cls.__mro__
|
||||
and not issubclass(match, t)):
|
||||
raise RuntimeError("Ambiguous dispatch: {} or {}".format(
|
||||
match, t))
|
||||
break
|
||||
|
@ -418,19 +510,19 @@ def singledispatch(func):
|
|||
Transforms a function into a generic function, which can have different
|
||||
behaviours depending upon the type of its first argument. The decorated
|
||||
function acts as the default implementation, and additional
|
||||
implementations can be registered using the 'register()' attribute of
|
||||
the generic function.
|
||||
implementations can be registered using the register() attribute of the
|
||||
generic function.
|
||||
|
||||
"""
|
||||
registry = {}
|
||||
dispatch_cache = WeakKeyDictionary()
|
||||
cache_token = None
|
||||
|
||||
def dispatch(typ):
|
||||
"""generic_func.dispatch(type) -> <function implementation>
|
||||
def dispatch(cls):
|
||||
"""generic_func.dispatch(cls) -> <function implementation>
|
||||
|
||||
Runs the dispatch algorithm to return the best available implementation
|
||||
for the given `type` registered on `generic_func`.
|
||||
for the given *cls* registered on *generic_func*.
|
||||
|
||||
"""
|
||||
nonlocal cache_token
|
||||
|
@ -440,26 +532,26 @@ def singledispatch(func):
|
|||
dispatch_cache.clear()
|
||||
cache_token = current_token
|
||||
try:
|
||||
impl = dispatch_cache[typ]
|
||||
impl = dispatch_cache[cls]
|
||||
except KeyError:
|
||||
try:
|
||||
impl = registry[typ]
|
||||
impl = registry[cls]
|
||||
except KeyError:
|
||||
impl = _find_impl(typ, registry)
|
||||
dispatch_cache[typ] = impl
|
||||
impl = _find_impl(cls, registry)
|
||||
dispatch_cache[cls] = impl
|
||||
return impl
|
||||
|
||||
def register(typ, func=None):
|
||||
"""generic_func.register(type, func) -> func
|
||||
def register(cls, func=None):
|
||||
"""generic_func.register(cls, func) -> func
|
||||
|
||||
Registers a new implementation for the given `type` on a `generic_func`.
|
||||
Registers a new implementation for the given *cls* on a *generic_func*.
|
||||
|
||||
"""
|
||||
nonlocal cache_token
|
||||
if func is None:
|
||||
return lambda f: register(typ, f)
|
||||
registry[typ] = func
|
||||
if cache_token is None and hasattr(typ, '__abstractmethods__'):
|
||||
return lambda f: register(cls, f)
|
||||
registry[cls] = func
|
||||
if cache_token is None and hasattr(cls, '__abstractmethods__'):
|
||||
cache_token = get_cache_token()
|
||||
dispatch_cache.clear()
|
||||
return func
|
||||
|
|
|
@ -929,22 +929,55 @@ class TestSingleDispatch(unittest.TestCase):
|
|||
self.assertEqual(g(rnd), ("Number got rounded",))
|
||||
|
||||
def test_compose_mro(self):
    # None of the examples in this test depend on haystack ordering.
    c = collections
    mro = functools._compose_mro
    bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
    for haystack in permutations(bases):
        m = mro(dict, haystack)
        self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
                             c.Iterable, c.Container, object])
    bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
    for haystack in permutations(bases):
        m = mro(c.ChainMap, haystack)
        self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
                             c.Sized, c.Iterable, c.Container, object])

    # If there's a generic function with implementations registered for
    # both Sized and Container, passing a defaultdict to it results in an
    # ambiguous dispatch which will cause a RuntimeError (see
    # test_mro_conflicts).
    bases = [c.Container, c.Sized, str]
    for haystack in permutations(bases):
        # Pass the permutation itself so every ordering is really checked.
        m = mro(c.defaultdict, haystack)
        self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
                             object])

    # MutableSequence below is registered directly on D. In other words, it
    # precedes MutableMapping which means single dispatch will always
    # choose MutableSequence here.
    class D(c.defaultdict):
        pass
    c.MutableSequence.register(D)
    bases = [c.MutableSequence, c.MutableMapping]
    for haystack in permutations(bases):
        # Pass the permutation itself so every ordering is really checked.
        m = mro(D, haystack)
        self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
                             c.defaultdict, dict, c.MutableMapping,
                             c.Mapping, c.Sized, c.Iterable, c.Container,
                             object])

    # Container and Callable are registered on different base classes and
    # a generic function supporting both should always pick the Callable
    # implementation if a C instance is passed.
    class C(c.defaultdict):
        def __call__(self):
            pass
    bases = [c.Sized, c.Callable, c.Container, c.Mapping]
    for haystack in permutations(bases):
        m = mro(C, haystack)
        self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
                             c.Sized, c.Iterable, c.Container, object])
|
||||
|
||||
def test_register_abc(self):
|
||||
c = collections
|
||||
|
@ -1040,17 +1073,37 @@ class TestSingleDispatch(unittest.TestCase):
|
|||
self.assertEqual(g(f), "frozen-set")
|
||||
self.assertEqual(g(t), "tuple")
|
||||
|
||||
def test_c3_abc(self):
    # Exercises functools._c3_mro on a hierarchy that mixes an ABC inferred
    # from a special method, an ABC registered on a base, and an unrelated
    # base class.
    c = collections
    c3_mro = functools._c3_mro

    class A(object):
        pass

    class B(A):
        def __len__(self):
            return 0   # implies Sized

    class C(object):
        pass
    C = c.Container.register(C)

    class D(object):
        pass   # unrelated

    class X(D, C, B):
        def __call__(self):
            pass   # implies Callable

    expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
    relevant_abcs = [c.Sized, c.Callable, c.Container]
    for ordering in permutations(relevant_abcs):
        self.assertEqual(c3_mro(X, abcs=ordering), expected)
    # unrelated ABCs don't appear in the resulting MRO
    many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
    self.assertEqual(c3_mro(X, abcs=many_abcs), expected)
|
||||
|
||||
def test_mro_conflicts(self):
|
||||
c = collections
|
||||
|
||||
@functools.singledispatch
|
||||
def g(arg):
|
||||
return "base"
|
||||
|
||||
class O(c.Sized):
|
||||
def __len__(self):
|
||||
return 0
|
||||
|
||||
o = O()
|
||||
self.assertEqual(g(o), "base")
|
||||
g.register(c.Iterable, lambda arg: "iterable")
|
||||
|
@ -1062,35 +1115,114 @@ class TestSingleDispatch(unittest.TestCase):
|
|||
self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
|
||||
c.Container.register(O)
|
||||
self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
|
||||
|
||||
c.Set.register(O)
|
||||
self.assertEqual(g(o), "set") # because c.Set is a subclass of
|
||||
# c.Sized and c.Container
|
||||
class P:
|
||||
pass
|
||||
|
||||
p = P()
|
||||
self.assertEqual(g(p), "base")
|
||||
c.Iterable.register(P)
|
||||
self.assertEqual(g(p), "iterable")
|
||||
c.Container.register(P)
|
||||
with self.assertRaises(RuntimeError) as re:
|
||||
with self.assertRaises(RuntimeError) as re_one:
|
||||
g(p)
|
||||
self.assertEqual(
|
||||
str(re),
|
||||
("Ambiguous dispatch: <class 'collections.abc.Container'> "
|
||||
"or <class 'collections.abc.Iterable'>"),
|
||||
)
|
||||
|
||||
self.assertIn(
|
||||
str(re_one.exception),
|
||||
(("Ambiguous dispatch: <class 'collections.abc.Container'> "
|
||||
"or <class 'collections.abc.Iterable'>"),
|
||||
("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
|
||||
"or <class 'collections.abc.Container'>")),
|
||||
)
|
||||
class Q(c.Sized):
|
||||
def __len__(self):
|
||||
return 0
|
||||
|
||||
q = Q()
|
||||
self.assertEqual(g(q), "sized")
|
||||
c.Iterable.register(Q)
|
||||
self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
|
||||
c.Set.register(Q)
|
||||
self.assertEqual(g(q), "set") # because c.Set is a subclass of
|
||||
# c.Sized which is explicitly in
|
||||
# __mro__
|
||||
# c.Sized and c.Iterable
|
||||
@functools.singledispatch
|
||||
def h(arg):
|
||||
return "base"
|
||||
@h.register(c.Sized)
|
||||
def _(arg):
|
||||
return "sized"
|
||||
@h.register(c.Container)
|
||||
def _(arg):
|
||||
return "container"
|
||||
# Even though Sized and Container are explicit bases of MutableMapping,
|
||||
# this ABC is implicitly registered on defaultdict which makes all of
|
||||
# MutableMapping's bases implicit as well from defaultdict's
|
||||
# perspective.
|
||||
with self.assertRaises(RuntimeError) as re_two:
|
||||
h(c.defaultdict(lambda: 0))
|
||||
self.assertIn(
|
||||
str(re_two.exception),
|
||||
(("Ambiguous dispatch: <class 'collections.abc.Container'> "
|
||||
"or <class 'collections.abc.Sized'>"),
|
||||
("Ambiguous dispatch: <class 'collections.abc.Sized'> "
|
||||
"or <class 'collections.abc.Container'>")),
|
||||
)
|
||||
class R(c.defaultdict):
|
||||
pass
|
||||
c.MutableSequence.register(R)
|
||||
@functools.singledispatch
|
||||
def i(arg):
|
||||
return "base"
|
||||
@i.register(c.MutableMapping)
|
||||
def _(arg):
|
||||
return "mapping"
|
||||
@i.register(c.MutableSequence)
|
||||
def _(arg):
|
||||
return "sequence"
|
||||
r = R()
|
||||
self.assertEqual(i(r), "sequence")
|
||||
class S:
|
||||
pass
|
||||
class T(S, c.Sized):
|
||||
def __len__(self):
|
||||
return 0
|
||||
t = T()
|
||||
self.assertEqual(h(t), "sized")
|
||||
c.Container.register(T)
|
||||
self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
|
||||
class U:
|
||||
def __len__(self):
|
||||
return 0
|
||||
u = U()
|
||||
self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
|
||||
# from the existence of __len__()
|
||||
c.Container.register(U)
|
||||
# There is no preference for registered versus inferred ABCs.
|
||||
with self.assertRaises(RuntimeError) as re_three:
|
||||
h(u)
|
||||
self.assertIn(
|
||||
str(re_three.exception),
|
||||
(("Ambiguous dispatch: <class 'collections.abc.Container'> "
|
||||
"or <class 'collections.abc.Sized'>"),
|
||||
("Ambiguous dispatch: <class 'collections.abc.Sized'> "
|
||||
"or <class 'collections.abc.Container'>")),
|
||||
)
|
||||
class V(c.Sized, S):
|
||||
def __len__(self):
|
||||
return 0
|
||||
@functools.singledispatch
|
||||
def j(arg):
|
||||
return "base"
|
||||
@j.register(S)
|
||||
def _(arg):
|
||||
return "s"
|
||||
@j.register(c.Container)
|
||||
def _(arg):
|
||||
return "container"
|
||||
v = V()
|
||||
self.assertEqual(j(v), "s")
|
||||
c.Container.register(V)
|
||||
self.assertEqual(j(v), "container") # because it ends up right after
|
||||
# Sized in the MRO
|
||||
|
||||
def test_cache_invalidation(self):
|
||||
from collections import UserDict
|
||||
|
|
Loading…
Reference in New Issue