bpo-31581: Reduce the number of imports for functools (GH-3757)

This commit is contained in:
INADA Naoki 2017-09-30 16:13:02 +09:00 committed by GitHub
parent b24cd055ec
commit 9811e80fd0
2 changed files with 92 additions and 88 deletions

View File

@ -19,8 +19,7 @@ except ImportError:
pass pass
from abc import get_cache_token from abc import get_cache_token
from collections import namedtuple from collections import namedtuple
from types import MappingProxyType # import types, weakref # Deferred to single_dispatch()
from weakref import WeakKeyDictionary
from reprlib import recursive_repr from reprlib import recursive_repr
from _thread import RLock from _thread import RLock
@ -753,10 +752,14 @@ def singledispatch(func):
function acts as the default implementation, and additional function acts as the default implementation, and additional
implementations can be registered using the register() attribute of the implementations can be registered using the register() attribute of the
generic function. generic function.
""" """
# There are many programs that use functools without singledispatch, so we
# trade-off making singledispatch marginally slower for the benefit of
# making start-up of such applications slightly faster.
import types, weakref
registry = {} registry = {}
dispatch_cache = WeakKeyDictionary() dispatch_cache = weakref.WeakKeyDictionary()
cache_token = None cache_token = None
def dispatch(cls): def dispatch(cls):
@ -803,7 +806,7 @@ def singledispatch(func):
registry[object] = func registry[object] = func
wrapper.register = register wrapper.register = register
wrapper.dispatch = dispatch wrapper.dispatch = dispatch
wrapper.registry = MappingProxyType(registry) wrapper.registry = types.MappingProxyType(registry)
wrapper._clear_cache = dispatch_cache.clear wrapper._clear_cache = dispatch_cache.clear
update_wrapper(wrapper, func) update_wrapper(wrapper, func)
return wrapper return wrapper

View File

@ -2019,6 +2019,8 @@ class TestSingleDispatch(unittest.TestCase):
def test_cache_invalidation(self): def test_cache_invalidation(self):
from collections import UserDict from collections import UserDict
import weakref
class TracingDict(UserDict): class TracingDict(UserDict):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(TracingDict, self).__init__(*args, **kwargs) super(TracingDict, self).__init__(*args, **kwargs)
@ -2033,90 +2035,89 @@ class TestSingleDispatch(unittest.TestCase):
self.data[key] = value self.data[key] = value
def clear(self): def clear(self):
self.data.clear() self.data.clear()
_orig_wkd = functools.WeakKeyDictionary
td = TracingDict() td = TracingDict()
functools.WeakKeyDictionary = lambda: td with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
c = collections.abc c = collections.abc
@functools.singledispatch @functools.singledispatch
def g(arg): def g(arg):
return "base" return "base"
d = {} d = {}
l = [] l = []
self.assertEqual(len(td), 0) self.assertEqual(len(td), 0)
self.assertEqual(g(d), "base") self.assertEqual(g(d), "base")
self.assertEqual(len(td), 1) self.assertEqual(len(td), 1)
self.assertEqual(td.get_ops, []) self.assertEqual(td.get_ops, [])
self.assertEqual(td.set_ops, [dict]) self.assertEqual(td.set_ops, [dict])
self.assertEqual(td.data[dict], g.registry[object]) self.assertEqual(td.data[dict], g.registry[object])
self.assertEqual(g(l), "base") self.assertEqual(g(l), "base")
self.assertEqual(len(td), 2) self.assertEqual(len(td), 2)
self.assertEqual(td.get_ops, []) self.assertEqual(td.get_ops, [])
self.assertEqual(td.set_ops, [dict, list]) self.assertEqual(td.set_ops, [dict, list])
self.assertEqual(td.data[dict], g.registry[object]) self.assertEqual(td.data[dict], g.registry[object])
self.assertEqual(td.data[list], g.registry[object]) self.assertEqual(td.data[list], g.registry[object])
self.assertEqual(td.data[dict], td.data[list]) self.assertEqual(td.data[dict], td.data[list])
self.assertEqual(g(l), "base") self.assertEqual(g(l), "base")
self.assertEqual(g(d), "base") self.assertEqual(g(d), "base")
self.assertEqual(td.get_ops, [list, dict]) self.assertEqual(td.get_ops, [list, dict])
self.assertEqual(td.set_ops, [dict, list]) self.assertEqual(td.set_ops, [dict, list])
g.register(list, lambda arg: "list") g.register(list, lambda arg: "list")
self.assertEqual(td.get_ops, [list, dict]) self.assertEqual(td.get_ops, [list, dict])
self.assertEqual(len(td), 0) self.assertEqual(len(td), 0)
self.assertEqual(g(d), "base") self.assertEqual(g(d), "base")
self.assertEqual(len(td), 1) self.assertEqual(len(td), 1)
self.assertEqual(td.get_ops, [list, dict]) self.assertEqual(td.get_ops, [list, dict])
self.assertEqual(td.set_ops, [dict, list, dict]) self.assertEqual(td.set_ops, [dict, list, dict])
self.assertEqual(td.data[dict], self.assertEqual(td.data[dict],
functools._find_impl(dict, g.registry)) functools._find_impl(dict, g.registry))
self.assertEqual(g(l), "list") self.assertEqual(g(l), "list")
self.assertEqual(len(td), 2) self.assertEqual(len(td), 2)
self.assertEqual(td.get_ops, [list, dict]) self.assertEqual(td.get_ops, [list, dict])
self.assertEqual(td.set_ops, [dict, list, dict, list]) self.assertEqual(td.set_ops, [dict, list, dict, list])
self.assertEqual(td.data[list], self.assertEqual(td.data[list],
functools._find_impl(list, g.registry)) functools._find_impl(list, g.registry))
class X: class X:
pass pass
c.MutableMapping.register(X) # Will not invalidate the cache, c.MutableMapping.register(X) # Will not invalidate the cache,
# not using ABCs yet. # not using ABCs yet.
self.assertEqual(g(d), "base") self.assertEqual(g(d), "base")
self.assertEqual(g(l), "list") self.assertEqual(g(l), "list")
self.assertEqual(td.get_ops, [list, dict, dict, list]) self.assertEqual(td.get_ops, [list, dict, dict, list])
self.assertEqual(td.set_ops, [dict, list, dict, list]) self.assertEqual(td.set_ops, [dict, list, dict, list])
g.register(c.Sized, lambda arg: "sized") g.register(c.Sized, lambda arg: "sized")
self.assertEqual(len(td), 0) self.assertEqual(len(td), 0)
self.assertEqual(g(d), "sized") self.assertEqual(g(d), "sized")
self.assertEqual(len(td), 1) self.assertEqual(len(td), 1)
self.assertEqual(td.get_ops, [list, dict, dict, list]) self.assertEqual(td.get_ops, [list, dict, dict, list])
self.assertEqual(td.set_ops, [dict, list, dict, list, dict]) self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
self.assertEqual(g(l), "list") self.assertEqual(g(l), "list")
self.assertEqual(len(td), 2) self.assertEqual(len(td), 2)
self.assertEqual(td.get_ops, [list, dict, dict, list]) self.assertEqual(td.get_ops, [list, dict, dict, list])
self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
self.assertEqual(g(l), "list") self.assertEqual(g(l), "list")
self.assertEqual(g(d), "sized") self.assertEqual(g(d), "sized")
self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict]) self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
g.dispatch(list) g.dispatch(list)
g.dispatch(dict) g.dispatch(dict)
self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict, self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
list, dict]) list, dict])
self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
c.MutableSet.register(X) # Will invalidate the cache. c.MutableSet.register(X) # Will invalidate the cache.
self.assertEqual(len(td), 2) # Stale cache. self.assertEqual(len(td), 2) # Stale cache.
self.assertEqual(g(l), "list") self.assertEqual(g(l), "list")
self.assertEqual(len(td), 1) self.assertEqual(len(td), 1)
g.register(c.MutableMapping, lambda arg: "mutablemapping") g.register(c.MutableMapping, lambda arg: "mutablemapping")
self.assertEqual(len(td), 0) self.assertEqual(len(td), 0)
self.assertEqual(g(d), "mutablemapping") self.assertEqual(g(d), "mutablemapping")
self.assertEqual(len(td), 1) self.assertEqual(len(td), 1)
self.assertEqual(g(l), "list") self.assertEqual(g(l), "list")
self.assertEqual(len(td), 2) self.assertEqual(len(td), 2)
g.register(dict, lambda arg: "dict") g.register(dict, lambda arg: "dict")
self.assertEqual(g(d), "dict") self.assertEqual(g(d), "dict")
self.assertEqual(g(l), "list") self.assertEqual(g(l), "list")
g._clear_cache() g._clear_cache()
self.assertEqual(len(td), 0) self.assertEqual(len(td), 0)
functools.WeakKeyDictionary = _orig_wkd
if __name__ == '__main__': if __name__ == '__main__':