import abc import collections import copy from itertools import permutations import pickle from random import choice import sys from test import support import unittest from weakref import proxy import contextlib try: import threading except ImportError: threading = None import functools py_functools = support.import_fresh_module('functools', blocked=['_functools']) c_functools = support.import_fresh_module('functools', fresh=['_functools']) decimal = support.import_fresh_module('decimal', fresh=['_decimal']) @contextlib.contextmanager def replaced_module(name, replacement): original_module = sys.modules[name] sys.modules[name] = replacement try: yield finally: sys.modules[name] = original_module def capture(*args, **kw): """capture all positional and keyword arguments""" return args, kw def signature(part): """ return the signature of a partial object """ return (part.func, part.args, part.keywords, part.__dict__) class MyTuple(tuple): pass class BadTuple(tuple): def __add__(self, other): return list(self) + list(other) class MyDict(dict): pass class TestPartial: def test_basic_examples(self): p = self.partial(capture, 1, 2, a=10, b=20) self.assertTrue(callable(p)) self.assertEqual(p(3, 4, b=30, c=40), ((1, 2, 3, 4), dict(a=10, b=30, c=40))) p = self.partial(map, lambda x: x*10) self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40]) def test_attributes(self): p = self.partial(capture, 1, 2, a=10, b=20) # attributes should be readable self.assertEqual(p.func, capture) self.assertEqual(p.args, (1, 2)) self.assertEqual(p.keywords, dict(a=10, b=20)) def test_argument_checking(self): self.assertRaises(TypeError, self.partial) # need at least a func arg try: self.partial(2)() except TypeError: pass else: self.fail('First arg not checked for callability') def test_protection_of_callers_dict_argument(self): # a caller's dictionary should not be altered by partial def func(a=10, b=20): return a d = {'a':3} p = self.partial(func, a=5) self.assertEqual(p(**d), 3) self.assertEqual(d, {'a':3}) p(b=7) self.assertEqual(d, {'a':3}) def test_arg_combinations(self): # exercise special code paths for zero args in either partial # object or the caller p = self.partial(capture) self.assertEqual(p(), ((), {})) self.assertEqual(p(1,2), ((1,2), {})) p = self.partial(capture, 1, 2) self.assertEqual(p(), ((1,2), {})) self.assertEqual(p(3,4), ((1,2,3,4), {})) def test_kw_combinations(self): # exercise special code paths for no keyword args in # either the partial object or the caller p = self.partial(capture) self.assertEqual(p.keywords, {}) self.assertEqual(p(), ((), {})) self.assertEqual(p(a=1), ((), {'a':1})) p = self.partial(capture, a=1) self.assertEqual(p.keywords, {'a':1}) self.assertEqual(p(), ((), {'a':1})) self.assertEqual(p(b=2), ((), {'a':1, 'b':2})) # keyword args in the call override those in the partial object self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2})) def test_positional(self): # make sure positional arguments are captured correctly for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]: p = self.partial(capture, *args) expected = args + ('x',) got, empty = p('x') self.assertTrue(expected == got and empty == {}) def test_keyword(self): # make sure keyword arguments are captured correctly for a in ['a', 0, None, 3.5]: p = self.partial(capture, a=a) expected = {'a':a,'x':None} empty, got = p(x=None) self.assertTrue(expected == got and empty == ()) def test_no_side_effects(self): # make sure there are no side effects that affect subsequent calls p = self.partial(capture, 0, a=1) args1, kw1 = p(1, b=2) self.assertTrue(args1 == (0,1) and kw1 == {'a':1,'b':2}) args2, kw2 = p() self.assertTrue(args2 == (0,) and kw2 == {'a':1}) def test_error_propagation(self): def f(x, y): x / y self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0)) self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0) self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0) self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1) def test_weakref(self): f = self.partial(int, base=16) p = proxy(f) self.assertEqual(f.func, p.func) f = None self.assertRaises(ReferenceError, getattr, p, 'func') def test_with_bound_and_unbound_methods(self): data = list(map(str, range(10))) join = self.partial(str.join, '') self.assertEqual(join(data), '0123456789') join = self.partial(''.join) self.assertEqual(join(data), '0123456789') def test_nested_optimization(self): partial = self.partial inner = partial(signature, 'asdf') nested = partial(inner, bar=True) flat = partial(signature, 'asdf', bar=True) self.assertEqual(signature(nested), signature(flat)) def test_nested_partial_with_attribute(self): # see issue 25137 partial = self.partial def foo(bar): return bar p = partial(foo, 'first') p2 = partial(p, 'second') p2.new_attr = 'spam' self.assertEqual(p2.new_attr, 'spam') def test_repr(self): args = (object(), object()) args_repr = ', '.join(repr(a) for a in args) kwargs = {'a': object(), 'b': object()} kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs), 'b={b!r}, a={a!r}'.format_map(kwargs)] if self.partial in (c_functools.partial, py_functools.partial): name = 'functools.partial' else: name = self.partial.__name__ f = self.partial(capture) self.assertEqual(f'{name}({capture!r})', repr(f)) f = self.partial(capture, *args) self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f)) f = self.partial(capture, **kwargs) self.assertIn(repr(f), [f'{name}({capture!r}, {kwargs_repr})' for kwargs_repr in kwargs_reprs]) f = self.partial(capture, *args, **kwargs) self.assertIn(repr(f), [f'{name}({capture!r}, {args_repr}, {kwargs_repr})' for kwargs_repr in kwargs_reprs]) def test_recursive_repr(self): if self.partial in (c_functools.partial, py_functools.partial): name = 'functools.partial' else: name = self.partial.__name__ f = self.partial(capture) f.__setstate__((f, (), {}, {})) try: self.assertEqual(repr(f), '%s(...)' % (name,)) finally: f.__setstate__((capture, (), {}, {})) f = self.partial(capture) f.__setstate__((capture, (f,), {}, {})) try: self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,)) finally: f.__setstate__((capture, (), {}, {})) f = self.partial(capture) f.__setstate__((capture, (), {'a': f}, {})) try: self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,)) finally: f.__setstate__((capture, (), {}, {})) def test_pickle(self): with self.AllowPickle(): f = self.partial(signature, ['asdf'], bar=[True]) f.attr = [] for proto in range(pickle.HIGHEST_PROTOCOL + 1): f_copy = pickle.loads(pickle.dumps(f, proto)) self.assertEqual(signature(f_copy), signature(f)) def test_copy(self): f = self.partial(signature, ['asdf'], bar=[True]) f.attr = [] f_copy = copy.copy(f) self.assertEqual(signature(f_copy), signature(f)) self.assertIs(f_copy.attr, f.attr) self.assertIs(f_copy.args, f.args) self.assertIs(f_copy.keywords, f.keywords) def test_deepcopy(self): f = self.partial(signature, ['asdf'], bar=[True]) f.attr = [] f_copy = copy.deepcopy(f) self.assertEqual(signature(f_copy), signature(f)) self.assertIsNot(f_copy.attr, f.attr) self.assertIsNot(f_copy.args, f.args) self.assertIsNot(f_copy.args[0], f.args[0]) self.assertIsNot(f_copy.keywords, f.keywords) self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar']) def test_setstate(self): f = self.partial(signature) f.__setstate__((capture, (1,), dict(a=10), dict(attr=[]))) self.assertEqual(signature(f), (capture, (1,), dict(a=10), dict(attr=[]))) self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20})) f.__setstate__((capture, (1,), dict(a=10), None)) self.assertEqual(signature(f), (capture, (1,), dict(a=10), {})) self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20})) f.__setstate__((capture, (1,), None, None)) #self.assertEqual(signature(f), (capture, (1,), {}, {})) self.assertEqual(f(2, b=20), ((1, 2), {'b': 20})) self.assertEqual(f(2), ((1, 2), {})) self.assertEqual(f(), ((1,), {})) f.__setstate__((capture, (), {}, None)) self.assertEqual(signature(f), (capture, (), {}, {})) self.assertEqual(f(2, b=20), ((2,), {'b': 20})) self.assertEqual(f(2), ((2,), {})) self.assertEqual(f(), ((), {})) def test_setstate_errors(self): f = self.partial(signature) self.assertRaises(TypeError, f.__setstate__, (capture, (), {})) self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None)) self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None]) self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None)) self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None)) self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None)) self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None)) def test_setstate_subclasses(self): f = self.partial(signature) f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None)) s = signature(f) self.assertEqual(s, (capture, (1,), dict(a=10), {})) self.assertIs(type(s[1]), tuple) self.assertIs(type(s[2]), dict) r = f() self.assertEqual(r, ((1,), {'a': 10})) self.assertIs(type(r[0]), tuple) self.assertIs(type(r[1]), dict) f.__setstate__((capture, BadTuple((1,)), {}, None)) s = signature(f) self.assertEqual(s, (capture, (1,), {}, {})) self.assertIs(type(s[1]), tuple) r = f(2) self.assertEqual(r, ((1, 2), {})) self.assertIs(type(r[0]), tuple) def test_recursive_pickle(self): with self.AllowPickle(): f = self.partial(capture) f.__setstate__((f, (), {}, {})) try: for proto in range(pickle.HIGHEST_PROTOCOL + 1): with self.assertRaises(RecursionError): pickle.dumps(f, proto) finally: f.__setstate__((capture, (), {}, {})) f = self.partial(capture) f.__setstate__((capture, (f,), {}, {})) try: for proto in range(pickle.HIGHEST_PROTOCOL + 1): f_copy = pickle.loads(pickle.dumps(f, proto)) try: self.assertIs(f_copy.args[0], f_copy) finally: f_copy.__setstate__((capture, (), {}, {})) finally: f.__setstate__((capture, (), {}, {})) f = self.partial(capture) f.__setstate__((capture, (), {'a': f}, {})) try: for proto in range(pickle.HIGHEST_PROTOCOL + 1): f_copy = pickle.loads(pickle.dumps(f, proto)) try: self.assertIs(f_copy.keywords['a'], f_copy) finally: f_copy.__setstate__((capture, (), {}, {})) finally: f.__setstate__((capture, (), {}, {})) # Issue 6083: Reference counting bug def test_setstate_refcount(self): class BadSequence: def __len__(self): return 4 def __getitem__(self, key): if key == 0: return max elif key == 1: return tuple(range(1000000)) elif key in (2, 3): return {} raise IndexError f = self.partial(object) self.assertRaises(TypeError, f.__setstate__, BadSequence()) @unittest.skipUnless(c_functools, 'requires the C _functools module') class TestPartialC(TestPartial, unittest.TestCase): if c_functools: partial = c_functools.partial class AllowPickle: def __enter__(self): return self def __exit__(self, type, value, tb): return False def test_attributes_unwritable(self): # attributes should not be writable p = self.partial(capture, 1, 2, a=10, b=20) self.assertRaises(AttributeError, setattr, p, 'func', map) self.assertRaises(AttributeError, setattr, p, 'args', (1, 2)) self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2)) p = self.partial(hex) try: del p.__dict__ except TypeError: pass else: self.fail('partial object allowed __dict__ to be deleted') class TestPartialPy(TestPartial, unittest.TestCase): partial = py_functools.partial class AllowPickle: def __init__(self): self._cm = replaced_module("functools", py_functools) def __enter__(self): return self._cm.__enter__() def __exit__(self, type, value, tb): return self._cm.__exit__(type, value, tb) if c_functools: class CPartialSubclass(c_functools.partial): pass class PyPartialSubclass(py_functools.partial): pass @unittest.skipUnless(c_functools, 'requires the C _functools module') class TestPartialCSubclass(TestPartialC): if c_functools: partial = CPartialSubclass # partial subclasses are not optimized for nested calls test_nested_optimization = None class TestPartialPySubclass(TestPartialPy): partial = PyPartialSubclass class TestPartialMethod(unittest.TestCase): class A(object): nothing = functools.partialmethod(capture) positional = functools.partialmethod(capture, 1) keywords = functools.partialmethod(capture, a=2) both = functools.partialmethod(capture, 3, b=4) nested = functools.partialmethod(positional, 5) over_partial = functools.partialmethod(functools.partial(capture, c=6), 7) static = functools.partialmethod(staticmethod(capture), 8) cls = functools.partialmethod(classmethod(capture), d=9) a = A() def test_arg_combinations(self): self.assertEqual(self.a.nothing(), ((self.a,), {})) self.assertEqual(self.a.nothing(5), ((self.a, 5), {})) self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6})) self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6})) self.assertEqual(self.a.positional(), ((self.a, 1), {})) self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {})) self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6})) self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6})) self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2})) self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2})) self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6})) self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6})) self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4})) self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4})) self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6})) self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6})) self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6})) def test_nested(self): self.assertEqual(self.a.nested(), ((self.a, 1, 5), {})) self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {})) self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7})) self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7})) self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7})) def test_over_partial(self): self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6})) self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6})) self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8})) self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8})) self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8})) def test_bound_method_introspection(self): obj = self.a self.assertIs(obj.both.__self__, obj) self.assertIs(obj.nested.__self__, obj) self.assertIs(obj.over_partial.__self__, obj) self.assertIs(obj.cls.__self__, self.A) self.assertIs(self.A.cls.__self__, self.A) def test_unbound_method_retrieval(self): obj = self.A self.assertFalse(hasattr(obj.both, "__self__")) self.assertFalse(hasattr(obj.nested, "__self__")) self.assertFalse(hasattr(obj.over_partial, "__self__")) self.assertFalse(hasattr(obj.static, "__self__")) self.assertFalse(hasattr(self.a.static, "__self__")) def test_descriptors(self): for obj in [self.A, self.a]: with self.subTest(obj=obj): self.assertEqual(obj.static(), ((8,), {})) self.assertEqual(obj.static(5), ((8, 5), {})) self.assertEqual(obj.static(d=8), ((8,), {'d': 8})) self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8})) self.assertEqual(obj.cls(), ((self.A,), {'d': 9})) self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9})) self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9})) self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9})) def test_overriding_keywords(self): self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3})) self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3})) def test_invalid_args(self): with self.assertRaises(TypeError): class B(object): method = functools.partialmethod(None, 1) def test_repr(self): self.assertEqual(repr(vars(self.A)['both']), 'functools.partialmethod({}, 3, b=4)'.format(capture)) def test_abstract(self): class Abstract(abc.ABCMeta): @abc.abstractmethod def add(self, x, y): pass add5 = functools.partialmethod(add, 5) self.assertTrue(Abstract.add.__isabstractmethod__) self.assertTrue(Abstract.add5.__isabstractmethod__) for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]: self.assertFalse(getattr(func, '__isabstractmethod__', False)) class TestUpdateWrapper(unittest.TestCase): def check_wrapper(self, wrapper, wrapped, assigned=functools.WRAPPER_ASSIGNMENTS, updated=functools.WRAPPER_UPDATES): # Check attributes were assigned for name in assigned: self.assertIs(getattr(wrapper, name), getattr(wrapped, name)) # Check attributes were updated for name in updated: wrapper_attr = getattr(wrapper, name) wrapped_attr = getattr(wrapped, name) for key in wrapped_attr: if name == "__dict__" and key == "__wrapped__": # __wrapped__ is overwritten by the update code continue self.assertIs(wrapped_attr[key], wrapper_attr[key]) # Check __wrapped__ self.assertIs(wrapper.__wrapped__, wrapped) def _default_update(self): def f(a:'This is a new annotation'): """This is a test""" pass f.attr = 'This is also a test' f.__wrapped__ = "This is a bald faced lie" def wrapper(b:'This is the prior annotation'): pass functools.update_wrapper(wrapper, f) return wrapper, f def test_default_update(self): wrapper, f = self._default_update() self.check_wrapper(wrapper, f) self.assertIs(wrapper.__wrapped__, f) self.assertEqual(wrapper.__name__, 'f') self.assertEqual(wrapper.__qualname__, f.__qualname__) self.assertEqual(wrapper.attr, 'This is also a test') self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation') self.assertNotIn('b', wrapper.__annotations__) @unittest.skipIf(sys.flags.optimize >= 2, "Docstrings are omitted with -O2 and above") def test_default_update_doc(self): wrapper, f = self._default_update() self.assertEqual(wrapper.__doc__, 'This is a test') def test_no_update(self): def f(): """This is a test""" pass f.attr = 'This is also a test' def wrapper(): pass functools.update_wrapper(wrapper, f, (), ()) self.check_wrapper(wrapper, f, (), ()) self.assertEqual(wrapper.__name__, 'wrapper') self.assertNotEqual(wrapper.__qualname__, f.__qualname__) self.assertEqual(wrapper.__doc__, None) self.assertEqual(wrapper.__annotations__, {}) self.assertFalse(hasattr(wrapper, 'attr')) def test_selective_update(self): def f(): pass f.attr = 'This is a different test' f.dict_attr = dict(a=1, b=2, c=3) def wrapper(): pass wrapper.dict_attr = {} assign = ('attr',) update = ('dict_attr',) functools.update_wrapper(wrapper, f, assign, update) self.check_wrapper(wrapper, f, assign, update) self.assertEqual(wrapper.__name__, 'wrapper') self.assertNotEqual(wrapper.__qualname__, f.__qualname__) self.assertEqual(wrapper.__doc__, None) self.assertEqual(wrapper.attr, 'This is a different test') self.assertEqual(wrapper.dict_attr, f.dict_attr) def test_missing_attributes(self): def f(): pass def wrapper(): pass wrapper.dict_attr = {} assign = ('attr',) update = ('dict_attr',) # Missing attributes on wrapped object are ignored functools.update_wrapper(wrapper, f, assign, update) self.assertNotIn('attr', wrapper.__dict__) self.assertEqual(wrapper.dict_attr, {}) # Wrapper must have expected attributes for updating del wrapper.dict_attr with self.assertRaises(AttributeError): functools.update_wrapper(wrapper, f, assign, update) wrapper.dict_attr = 1 with self.assertRaises(AttributeError): functools.update_wrapper(wrapper, f, assign, update) @support.requires_docstrings @unittest.skipIf(sys.flags.optimize >= 2, "Docstrings are omitted with -O2 and above") def test_builtin_update(self): # Test for bug #1576241 def wrapper(): pass functools.update_wrapper(wrapper, max) self.assertEqual(wrapper.__name__, 'max') self.assertTrue(wrapper.__doc__.startswith('max(')) self.assertEqual(wrapper.__annotations__, {}) class TestWraps(TestUpdateWrapper): def _default_update(self): def f(): """This is a test""" pass f.attr = 'This is also a test' f.__wrapped__ = "This is still a bald faced lie" @functools.wraps(f) def wrapper(): pass return wrapper, f def test_default_update(self): wrapper, f = self._default_update() self.check_wrapper(wrapper, f) self.assertEqual(wrapper.__name__, 'f') self.assertEqual(wrapper.__qualname__, f.__qualname__) self.assertEqual(wrapper.attr, 'This is also a test') @unittest.skipIf(sys.flags.optimize >= 2, "Docstrings are omitted with -O2 and above") def test_default_update_doc(self): wrapper, _ = self._default_update() self.assertEqual(wrapper.__doc__, 'This is a test') def test_no_update(self): def f(): """This is a test""" pass f.attr = 'This is also a test' @functools.wraps(f, (), ()) def wrapper(): pass self.check_wrapper(wrapper, f, (), ()) self.assertEqual(wrapper.__name__, 'wrapper') self.assertNotEqual(wrapper.__qualname__, f.__qualname__) self.assertEqual(wrapper.__doc__, None) self.assertFalse(hasattr(wrapper, 'attr')) def test_selective_update(self): def f(): pass f.attr = 'This is a different test' f.dict_attr = dict(a=1, b=2, c=3) def add_dict_attr(f): f.dict_attr = {} return f assign = ('attr',) update = ('dict_attr',) @functools.wraps(f, assign, update) @add_dict_attr def wrapper(): pass self.check_wrapper(wrapper, f, assign, update) self.assertEqual(wrapper.__name__, 'wrapper') self.assertNotEqual(wrapper.__qualname__, f.__qualname__) self.assertEqual(wrapper.__doc__, None) self.assertEqual(wrapper.attr, 'This is a different test') self.assertEqual(wrapper.dict_attr, f.dict_attr) @unittest.skipUnless(c_functools, 'requires the C _functools module') class TestReduce(unittest.TestCase): if c_functools: func = c_functools.reduce def test_reduce(self): class Squares: def __init__(self, max): self.max = max self.sofar = [] def __len__(self): return len(self.sofar) def __getitem__(self, i): if not 0 <= i < self.max: raise IndexError n = len(self.sofar) while n <= i: self.sofar.append(n*n) n += 1 return self.sofar[i] def add(x, y): return x + y self.assertEqual(self.func(add, ['a', 'b', 'c'], ''), 'abc') self.assertEqual( self.func(add, [['a', 'c'], [], ['d', 'w']], []), ['a','c','d','w'] ) self.assertEqual(self.func(lambda x, y: x*y, range(2,8), 1), 5040) self.assertEqual( self.func(lambda x, y: x*y, range(2,21), 1), 2432902008176640000 ) self.assertEqual(self.func(add, Squares(10)), 285) self.assertEqual(self.func(add, Squares(10), 0), 285) self.assertEqual(self.func(add, Squares(0), 0), 0) self.assertRaises(TypeError, self.func) self.assertRaises(TypeError, self.func, 42, 42) self.assertRaises(TypeError, self.func, 42, 42, 42) self.assertEqual(self.func(42, "1"), "1") # func is never called with one item self.assertEqual(self.func(42, "", "1"), "1") # func is never called with one item self.assertRaises(TypeError, self.func, 42, (42, 42)) self.assertRaises(TypeError, self.func, add, []) # arg 2 must not be empty sequence with no initial value self.assertRaises(TypeError, self.func, add, "") self.assertRaises(TypeError, self.func, add, ()) self.assertRaises(TypeError, self.func, add, object()) class TestFailingIter: def __iter__(self): raise RuntimeError self.assertRaises(RuntimeError, self.func, add, TestFailingIter()) self.assertEqual(self.func(add, [], None), None) self.assertEqual(self.func(add, [], 42), 42) class BadSeq: def __getitem__(self, index): raise ValueError self.assertRaises(ValueError, self.func, 42, BadSeq()) # Test reduce()'s use of iterators. def test_iterator_usage(self): class SequenceClass: def __init__(self, n): self.n = n def __getitem__(self, i): if 0 <= i < self.n: return i else: raise IndexError from operator import add self.assertEqual(self.func(add, SequenceClass(5)), 10) self.assertEqual(self.func(add, SequenceClass(5), 42), 52) self.assertRaises(TypeError, self.func, add, SequenceClass(0)) self.assertEqual(self.func(add, SequenceClass(0), 42), 42) self.assertEqual(self.func(add, SequenceClass(1)), 0) self.assertEqual(self.func(add, SequenceClass(1), 42), 42) d = {"one": 1, "two": 2, "three": 3} self.assertEqual(self.func(add, d), "".join(d.keys())) class TestCmpToKey: def test_cmp_to_key(self): def cmp1(x, y): return (x > y) - (x < y) key = self.cmp_to_key(cmp1) self.assertEqual(key(3), key(3)) self.assertGreater(key(3), key(1)) self.assertGreaterEqual(key(3), key(3)) def cmp2(x, y): return int(x) - int(y) key = self.cmp_to_key(cmp2) self.assertEqual(key(4.0), key('4')) self.assertLess(key(2), key('35')) self.assertLessEqual(key(2), key('35')) self.assertNotEqual(key(2), key('35')) def test_cmp_to_key_arguments(self): def cmp1(x, y): return (x > y) - (x < y) key = self.cmp_to_key(mycmp=cmp1) self.assertEqual(key(obj=3), key(obj=3)) self.assertGreater(key(obj=3), key(obj=1)) with self.assertRaises((TypeError, AttributeError)): key(3) > 1 # rhs is not a K object with self.assertRaises((TypeError, AttributeError)): 1 < key(3) # lhs is not a K object with self.assertRaises(TypeError): key = self.cmp_to_key() # too few args with self.assertRaises(TypeError): key = self.cmp_to_key(cmp1, None) # too many args key = self.cmp_to_key(cmp1) with self.assertRaises(TypeError): key() # too few args with self.assertRaises(TypeError): key(None, None) # too many args def test_bad_cmp(self): def cmp1(x, y): raise ZeroDivisionError key = self.cmp_to_key(cmp1) with self.assertRaises(ZeroDivisionError): key(3) > key(1) class BadCmp: def __lt__(self, other): raise ZeroDivisionError def cmp1(x, y): return BadCmp() with self.assertRaises(ZeroDivisionError): key(3) > key(1) def test_obj_field(self): def cmp1(x, y): return (x > y) - (x < y) key = self.cmp_to_key(mycmp=cmp1) self.assertEqual(key(50).obj, 50) def test_sort_int(self): def mycmp(x, y): return y - x self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)), [4, 3, 2, 1, 0]) def test_sort_int_str(self): def mycmp(x, y): x, y = int(x), int(y) return (x > y) - (x < y) values = [5, '3', 7, 2, '0', '1', 4, '10', 1] values = sorted(values, key=self.cmp_to_key(mycmp)) self.assertEqual([int(value) for value in values], [0, 1, 1, 2, 3, 4, 5, 7, 10]) def test_hash(self): def mycmp(x, y): return y - x key = self.cmp_to_key(mycmp) k = key(10) self.assertRaises(TypeError, hash, k) self.assertNotIsInstance(k, collections.Hashable) @unittest.skipUnless(c_functools, 'requires the C _functools module') class TestCmpToKeyC(TestCmpToKey, unittest.TestCase): if c_functools: cmp_to_key = c_functools.cmp_to_key class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase): cmp_to_key = staticmethod(py_functools.cmp_to_key) class TestTotalOrdering(unittest.TestCase): def test_total_ordering_lt(self): @functools.total_ordering class A: def __init__(self, value): self.value = value def __lt__(self, other): return self.value < other.value def __eq__(self, other): return self.value == other.value self.assertTrue(A(1) < A(2)) self.assertTrue(A(2) > A(1)) self.assertTrue(A(1) <= A(2)) self.assertTrue(A(2) >= A(1)) self.assertTrue(A(2) <= A(2)) self.assertTrue(A(2) >= A(2)) self.assertFalse(A(1) > A(2)) def test_total_ordering_le(self): @functools.total_ordering class A: def __init__(self, value): self.value = value def __le__(self, other): return self.value <= other.value def __eq__(self, other): return self.value == other.value self.assertTrue(A(1) < A(2)) self.assertTrue(A(2) > A(1)) self.assertTrue(A(1) <= A(2)) self.assertTrue(A(2) >= A(1)) self.assertTrue(A(2) <= A(2)) self.assertTrue(A(2) >= A(2)) self.assertFalse(A(1) >= A(2)) def test_total_ordering_gt(self): @functools.total_ordering class A: def __init__(self, value): self.value = value def __gt__(self, other): return self.value > other.value def __eq__(self, other): return self.value == other.value self.assertTrue(A(1) < A(2)) self.assertTrue(A(2) > A(1)) self.assertTrue(A(1) <= A(2)) self.assertTrue(A(2) >= A(1)) self.assertTrue(A(2) <= A(2)) self.assertTrue(A(2) >= A(2)) self.assertFalse(A(2) < A(1)) def test_total_ordering_ge(self): @functools.total_ordering class A: def __init__(self, value): self.value = value def __ge__(self, other): return self.value >= other.value def __eq__(self, other): return self.value == other.value self.assertTrue(A(1) < A(2)) self.assertTrue(A(2) > A(1)) self.assertTrue(A(1) <= A(2)) self.assertTrue(A(2) >= A(1)) self.assertTrue(A(2) <= A(2)) self.assertTrue(A(2) >= A(2)) self.assertFalse(A(2) <= A(1)) def test_total_ordering_no_overwrite(self): # new methods should not overwrite existing @functools.total_ordering class A(int): pass self.assertTrue(A(1) < A(2)) self.assertTrue(A(2) > A(1)) self.assertTrue(A(1) <= A(2)) self.assertTrue(A(2) >= A(1)) self.assertTrue(A(2) <= A(2)) self.assertTrue(A(2) >= A(2)) def test_no_operations_defined(self): with self.assertRaises(ValueError): @functools.total_ordering class A: pass def test_type_error_when_not_implemented(self): # bug 10042; ensure stack overflow does not occur # when decorated types return NotImplemented @functools.total_ordering class ImplementsLessThan: def __init__(self, value): self.value = value def __eq__(self, other): if isinstance(other, ImplementsLessThan): return self.value == other.value return False def __lt__(self, other): if isinstance(other, ImplementsLessThan): return self.value < other.value return NotImplemented @functools.total_ordering class ImplementsGreaterThan: def __init__(self, value): self.value = value def __eq__(self, other): if isinstance(other, ImplementsGreaterThan): return self.value == other.value return False def __gt__(self, other): if isinstance(other, ImplementsGreaterThan): return self.value > other.value return NotImplemented @functools.total_ordering class ImplementsLessThanEqualTo: def __init__(self, value): self.value = value def __eq__(self, other): if isinstance(other, ImplementsLessThanEqualTo): return self.value == other.value return False def __le__(self, other): if isinstance(other, ImplementsLessThanEqualTo): return self.value <= other.value return NotImplemented @functools.total_ordering class ImplementsGreaterThanEqualTo: def __init__(self, value): self.value = value def __eq__(self, other): if isinstance(other, ImplementsGreaterThanEqualTo): return self.value == other.value return False def __ge__(self, other): if isinstance(other, ImplementsGreaterThanEqualTo): return self.value >= other.value return NotImplemented @functools.total_ordering class ComparatorNotImplemented: def __init__(self, value): self.value = value def __eq__(self, other): if isinstance(other, ComparatorNotImplemented): return self.value == other.value return False def __lt__(self, other): return NotImplemented with self.subTest("LT < 1"), self.assertRaises(TypeError): ImplementsLessThan(-1) < 1 with self.subTest("LT < LE"), self.assertRaises(TypeError): ImplementsLessThan(0) < ImplementsLessThanEqualTo(0) with self.subTest("LT < GT"), self.assertRaises(TypeError): ImplementsLessThan(1) < ImplementsGreaterThan(1) with self.subTest("LE <= LT"), self.assertRaises(TypeError): ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2) with self.subTest("LE <= GE"), self.assertRaises(TypeError): ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3) with self.subTest("GT > GE"), self.assertRaises(TypeError): ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4) with self.subTest("GT > LT"), self.assertRaises(TypeError): ImplementsGreaterThan(5) > ImplementsLessThan(5) with self.subTest("GE >= GT"), self.assertRaises(TypeError): ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6) with self.subTest("GE >= LE"), self.assertRaises(TypeError): ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7) with self.subTest("GE when equal"): a = ComparatorNotImplemented(8) b = ComparatorNotImplemented(8) self.assertEqual(a, b) with self.assertRaises(TypeError): a >= b with self.subTest("LE when equal"): a = ComparatorNotImplemented(9) b = ComparatorNotImplemented(9) self.assertEqual(a, b) with self.assertRaises(TypeError): a <= b def test_pickle(self): for proto in range(pickle.HIGHEST_PROTOCOL + 1): for name in '__lt__', '__gt__', '__le__', '__ge__': with self.subTest(method=name, proto=proto): method = getattr(Orderable_LT, name) method_copy = pickle.loads(pickle.dumps(method, proto)) self.assertIs(method_copy, method) @functools.total_ordering class Orderable_LT: def __init__(self, value): self.value = value def __lt__(self, other): return self.value < other.value def __eq__(self, other): return self.value == other.value class TestLRU: def test_lru(self): def orig(x, y): return 3 * x + y f = self.module.lru_cache(maxsize=20)(orig) hits, misses, maxsize, currsize = f.cache_info() self.assertEqual(maxsize, 20) self.assertEqual(currsize, 0) self.assertEqual(hits, 0) self.assertEqual(misses, 0) domain = range(5) for i in range(1000): x, y = choice(domain), choice(domain) actual = f(x, y) expected = orig(x, y) self.assertEqual(actual, expected) hits, misses, maxsize, currsize = f.cache_info() self.assertTrue(hits > misses) self.assertEqual(hits + misses, 1000) self.assertEqual(currsize, 20) f.cache_clear() # test clearing hits, misses, maxsize, currsize = f.cache_info() self.assertEqual(hits, 0) self.assertEqual(misses, 0) self.assertEqual(currsize, 0) f(x, y) hits, misses, maxsize, currsize = f.cache_info() self.assertEqual(hits, 0) self.assertEqual(misses, 1) self.assertEqual(currsize, 1) # Test bypassing the cache self.assertIs(f.__wrapped__, orig) f.__wrapped__(x, y) hits, misses, maxsize, currsize = f.cache_info() self.assertEqual(hits, 0) self.assertEqual(misses, 1) self.assertEqual(currsize, 1) # test size zero (which means "never-cache") @self.module.lru_cache(0) def f(): nonlocal f_cnt f_cnt += 1 return 20 self.assertEqual(f.cache_info().maxsize, 0) f_cnt = 0 for i in range(5): self.assertEqual(f(), 20) self.assertEqual(f_cnt, 5) hits, misses, maxsize, currsize = f.cache_info() self.assertEqual(hits, 0) self.assertEqual(misses, 5) self.assertEqual(currsize, 0) # test size one @self.module.lru_cache(1) def f(): nonlocal f_cnt f_cnt += 1 return 20 self.assertEqual(f.cache_info().maxsize, 1) f_cnt = 0 for i in range(5): self.assertEqual(f(), 20) self.assertEqual(f_cnt, 1) hits, misses, maxsize, currsize = f.cache_info() self.assertEqual(hits, 4) self.assertEqual(misses, 1) self.assertEqual(currsize, 1) # test size two @self.module.lru_cache(2) def f(x): nonlocal f_cnt f_cnt += 1 return x*10 self.assertEqual(f.cache_info().maxsize, 2) f_cnt = 0 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7: # * * * * self.assertEqual(f(x), x*10) self.assertEqual(f_cnt, 4) hits, misses, maxsize, currsize = f.cache_info() self.assertEqual(hits, 12) self.assertEqual(misses, 4) self.assertEqual(currsize, 2) def test_lru_type_error(self): # Regression test for issue #28653. # lru_cache was leaking when one of the arguments # wasn't cacheable. @functools.lru_cache(maxsize=None) def infinite_cache(o): pass @functools.lru_cache(maxsize=10) def limited_cache(o): pass with self.assertRaises(TypeError): infinite_cache([]) with self.assertRaises(TypeError): limited_cache([]) def test_lru_with_maxsize_none(self): @self.module.lru_cache(maxsize=None) def fib(n): if n < 2: return n return fib(n-1) + fib(n-2) self.assertEqual([fib(n) for n in range(16)], [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) self.assertEqual(fib.cache_info(), self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) fib.cache_clear() self.assertEqual(fib.cache_info(), self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) def test_lru_with_maxsize_negative(self): @self.module.lru_cache(maxsize=-10) def eq(n): return n for i in (0, 1): self.assertEqual([eq(n) for n in range(150)], list(range(150))) self.assertEqual(eq.cache_info(), self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1)) def test_lru_with_exceptions(self): # Verify that user_function exceptions get passed through without # creating a hard-to-read chained exception. # http://bugs.python.org/issue13177 for maxsize in (None, 128): @self.module.lru_cache(maxsize) def func(i): return 'abc'[i] self.assertEqual(func(0), 'a') with self.assertRaises(IndexError) as cm: func(15) self.assertIsNone(cm.exception.__context__) # Verify that the previous exception did not result in a cached entry with self.assertRaises(IndexError): func(15) def test_lru_with_types(self): for maxsize in (None, 128): @self.module.lru_cache(maxsize=maxsize, typed=True) def square(x): return x * x self.assertEqual(square(3), 9) self.assertEqual(type(square(3)), type(9)) self.assertEqual(square(3.0), 9.0) self.assertEqual(type(square(3.0)), type(9.0)) self.assertEqual(square(x=3), 9) self.assertEqual(type(square(x=3)), type(9)) self.assertEqual(square(x=3.0), 9.0) self.assertEqual(type(square(x=3.0)), type(9.0)) self.assertEqual(square.cache_info().hits, 4) self.assertEqual(square.cache_info().misses, 4) def test_lru_with_keyword_args(self): @self.module.lru_cache() def fib(n): if n < 2: return n return fib(n=n-1) + fib(n=n-2) self.assertEqual( [fib(n=number) for number in range(16)], [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610] ) self.assertEqual(fib.cache_info(), self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16)) fib.cache_clear() self.assertEqual(fib.cache_info(), self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0)) def test_lru_with_keyword_args_maxsize_none(self): @self.module.lru_cache(maxsize=None) def fib(n): if n < 2: return n return fib(n=n-1) + fib(n=n-2) self.assertEqual([fib(n=number) for number in range(16)], [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) self.assertEqual(fib.cache_info(), self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) fib.cache_clear() self.assertEqual(fib.cache_info(), self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) def test_lru_cache_decoration(self): def f(zomg: 'zomg_annotation'): """f doc string""" return 42 g = self.module.lru_cache()(f) for attr in self.module.WRAPPER_ASSIGNMENTS: self.assertEqual(getattr(g, attr), getattr(f, attr)) @unittest.skipUnless(threading, 'This test requires threading.') def test_lru_cache_threaded(self): n, m = 5, 11 def orig(x, y): return 3 * x + y f = self.module.lru_cache(maxsize=n*m)(orig) hits, misses, maxsize, currsize = f.cache_info() self.assertEqual(currsize, 0) start = threading.Event() def full(k): start.wait(10) for _ in range(m): self.assertEqual(f(k, 0), orig(k, 0)) def clear(): start.wait(10) for _ in range(2*m): f.cache_clear() orig_si = sys.getswitchinterval() support.setswitchinterval(1e-6) try: # create n threads in order to fill cache threads = [threading.Thread(target=full, args=[k]) for k in range(n)] with support.start_threads(threads): start.set() hits, misses, maxsize, currsize = f.cache_info() if self.module is py_functools: # XXX: Why can be not equal? self.assertLessEqual(misses, n) self.assertLessEqual(hits, m*n - misses) else: self.assertEqual(misses, n) self.assertEqual(hits, m*n - misses) self.assertEqual(currsize, n) # create n threads in order to fill cache and 1 to clear it threads = [threading.Thread(target=clear)] threads += [threading.Thread(target=full, args=[k]) for k in range(n)] start.clear() with support.start_threads(threads): start.set() finally: sys.setswitchinterval(orig_si) @unittest.skipUnless(threading, 'This test requires threading.') def test_lru_cache_threaded2(self): # Simultaneous call with the same arguments n, m = 5, 7 start = threading.Barrier(n+1) pause = threading.Barrier(n+1) stop = threading.Barrier(n+1) @self.module.lru_cache(maxsize=m*n) def f(x): pause.wait(10) return 3 * x self.assertEqual(f.cache_info(), (0, 0, m*n, 0)) def test(): for i in range(m): start.wait(10) self.assertEqual(f(i), 3 * i) stop.wait(10) threads = [threading.Thread(target=test) for k in range(n)] with support.start_threads(threads): for i in range(m): start.wait(10) stop.reset() pause.wait(10) start.reset() stop.wait(10) pause.reset() self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1)) def test_need_for_rlock(self): # This will deadlock on an LRU cache that uses a regular lock @self.module.lru_cache(maxsize=10) def test_func(x): 'Used to demonstrate a reentrant lru_cache call within a single thread' return x class DoubleEq: 'Demonstrate a reentrant lru_cache call within a single thread' def __init__(self, x): self.x = x def __hash__(self): return self.x def __eq__(self, other): if self.x == 2: test_func(DoubleEq(1)) return self.x == other.x test_func(DoubleEq(1)) # Load the cache test_func(DoubleEq(2)) # Load the cache self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call DoubleEq(2)) # Verify the correct return value def test_early_detection_of_bad_call(self): # Issue #22184 with self.assertRaises(TypeError): @functools.lru_cache def f(): pass def test_lru_method(self): class X(int): f_cnt = 0 @self.module.lru_cache(2) def f(self, x): self.f_cnt += 1 return x*10+self a = X(5) b = X(5) c = X(7) self.assertEqual(X.f.cache_info(), (0, 0, 2, 0)) for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3: self.assertEqual(a.f(x), x*10 + 5) self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0)) self.assertEqual(X.f.cache_info(), (4, 6, 2, 2)) for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2: self.assertEqual(b.f(x), x*10 + 5) self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0)) self.assertEqual(X.f.cache_info(), (10, 10, 2, 2)) for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1: self.assertEqual(c.f(x), x*10 + 7) self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5)) self.assertEqual(X.f.cache_info(), (15, 15, 2, 2)) self.assertEqual(a.f.cache_info(), X.f.cache_info()) self.assertEqual(b.f.cache_info(), X.f.cache_info()) self.assertEqual(c.f.cache_info(), X.f.cache_info()) def test_pickle(self): cls = self.__class__ for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth: for proto in range(pickle.HIGHEST_PROTOCOL + 1): with self.subTest(proto=proto, func=f): f_copy = pickle.loads(pickle.dumps(f, proto)) self.assertIs(f_copy, f) def test_copy(self): cls = self.__class__ def orig(x, y): return 3 * x + y part = self.module.partial(orig, 2) funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth, self.module.lru_cache(2)(part)) for f in funcs: with self.subTest(func=f): f_copy = copy.copy(f) self.assertIs(f_copy, f) def test_deepcopy(self): cls = self.__class__ def orig(x, y): return 3 * x + y part = self.module.partial(orig, 2) funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth, self.module.lru_cache(2)(part)) for f in funcs: with self.subTest(func=f): f_copy = copy.deepcopy(f) self.assertIs(f_copy, f) @py_functools.lru_cache() def py_cached_func(x, y): return 3 * x + y @c_functools.lru_cache() def c_cached_func(x, y): return 3 * x + y class TestLRUPy(TestLRU, unittest.TestCase): module = py_functools cached_func = py_cached_func, @module.lru_cache() def cached_meth(self, x, y): return 3 * x + y @staticmethod @module.lru_cache() def cached_staticmeth(x, y): return 3 * x + y class TestLRUC(TestLRU, unittest.TestCase): module = c_functools cached_func = c_cached_func, @module.lru_cache() def cached_meth(self, x, y): return 3 * x + y @staticmethod @module.lru_cache() def cached_staticmeth(x, y): return 3 * x + y class TestSingleDispatch(unittest.TestCase): def test_simple_overloads(self): @functools.singledispatch def g(obj): return "base" def g_int(i): return "integer" g.register(int, g_int) self.assertEqual(g("str"), "base") self.assertEqual(g(1), "integer") self.assertEqual(g([1,2,3]), "base") def test_mro(self): @functools.singledispatch def g(obj): return "base" class A: pass class C(A): pass class B(A): pass class D(C, B): pass def g_A(a): return "A" def g_B(b): return "B" g.register(A, g_A) g.register(B, g_B) self.assertEqual(g(A()), "A") self.assertEqual(g(B()), "B") self.assertEqual(g(C()), "A") self.assertEqual(g(D()), "B") def test_register_decorator(self): @functools.singledispatch def g(obj): return "base" @g.register(int) def g_int(i): return "int %s" % (i,) self.assertEqual(g(""), "base") self.assertEqual(g(12), "int 12") self.assertIs(g.dispatch(int), g_int) self.assertIs(g.dispatch(object), g.dispatch(str)) # Note: in the assert above this is not g. # @singledispatch returns the wrapper. def test_wrapping_attributes(self): @functools.singledispatch def g(obj): "Simple test" return "Test" self.assertEqual(g.__name__, "g") if sys.flags.optimize < 2: self.assertEqual(g.__doc__, "Simple test") @unittest.skipUnless(decimal, 'requires _decimal') @support.cpython_only def test_c_classes(self): @functools.singledispatch def g(obj): return "base" @g.register(decimal.DecimalException) def _(obj): return obj.args subn = decimal.Subnormal("Exponent < Emin") rnd = decimal.Rounded("Number got rounded") self.assertEqual(g(subn), ("Exponent < Emin",)) self.assertEqual(g(rnd), ("Number got rounded",)) @g.register(decimal.Subnormal) def _(obj): return "Too small to care." self.assertEqual(g(subn), "Too small to care.") self.assertEqual(g(rnd), ("Number got rounded",)) def test_compose_mro(self): # None of the examples in this test depend on haystack ordering. c = collections mro = functools._compose_mro bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set] for haystack in permutations(bases): m = mro(dict, haystack) self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Collection, c.Sized, c.Iterable, c.Container, object]) bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict] for haystack in permutations(bases): m = mro(c.ChainMap, haystack) self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping, c.Collection, c.Sized, c.Iterable, c.Container, object]) # If there's a generic function with implementations registered for # both Sized and Container, passing a defaultdict to it results in an # ambiguous dispatch which will cause a RuntimeError (see # test_mro_conflicts). bases = [c.Container, c.Sized, str] for haystack in permutations(bases): m = mro(c.defaultdict, [c.Sized, c.Container, str]) self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container, object]) # MutableSequence below is registered directly on D. In other words, it # precedes MutableMapping which means single dispatch will always # choose MutableSequence here. class D(c.defaultdict): pass c.MutableSequence.register(D) bases = [c.MutableSequence, c.MutableMapping] for haystack in permutations(bases): m = mro(D, bases) self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible, c.defaultdict, dict, c.MutableMapping, c.Mapping, c.Collection, c.Sized, c.Iterable, c.Container, object]) # Container and Callable are registered on different base classes and # a generic function supporting both should always pick the Callable # implementation if a C instance is passed. class C(c.defaultdict): def __call__(self): pass bases = [c.Sized, c.Callable, c.Container, c.Mapping] for haystack in permutations(bases): m = mro(C, haystack) self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping, c.Collection, c.Sized, c.Iterable, c.Container, object]) def test_register_abc(self): c = collections d = {"a": "b"} l = [1, 2, 3] s = {object(), None} f = frozenset(s) t = (1, 2, 3) @functools.singledispatch def g(obj): return "base" self.assertEqual(g(d), "base") self.assertEqual(g(l), "base") self.assertEqual(g(s), "base") self.assertEqual(g(f), "base") self.assertEqual(g(t), "base") g.register(c.Sized, lambda obj: "sized") self.assertEqual(g(d), "sized") self.assertEqual(g(l), "sized") self.assertEqual(g(s), "sized") self.assertEqual(g(f), "sized") self.assertEqual(g(t), "sized") g.register(c.MutableMapping, lambda obj: "mutablemapping") self.assertEqual(g(d), "mutablemapping") self.assertEqual(g(l), "sized") self.assertEqual(g(s), "sized") self.assertEqual(g(f), "sized") self.assertEqual(g(t), "sized") g.register(c.ChainMap, lambda obj: "chainmap") self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered self.assertEqual(g(l), "sized") self.assertEqual(g(s), "sized") self.assertEqual(g(f), "sized") self.assertEqual(g(t), "sized") g.register(c.MutableSequence, lambda obj: "mutablesequence") self.assertEqual(g(d), "mutablemapping") self.assertEqual(g(l), "mutablesequence") self.assertEqual(g(s), "sized") self.assertEqual(g(f), "sized") self.assertEqual(g(t), "sized") g.register(c.MutableSet, lambda obj: "mutableset") self.assertEqual(g(d), "mutablemapping") self.assertEqual(g(l), "mutablesequence") self.assertEqual(g(s), "mutableset") self.assertEqual(g(f), "sized") self.assertEqual(g(t), "sized") g.register(c.Mapping, lambda obj: "mapping") self.assertEqual(g(d), "mutablemapping") # not specific enough self.assertEqual(g(l), "mutablesequence") self.assertEqual(g(s), "mutableset") self.assertEqual(g(f), "sized") self.assertEqual(g(t), "sized") g.register(c.Sequence, lambda obj: "sequence") self.assertEqual(g(d), "mutablemapping") self.assertEqual(g(l), "mutablesequence") self.assertEqual(g(s), "mutableset") self.assertEqual(g(f), "sized") self.assertEqual(g(t), "sequence") g.register(c.Set, lambda obj: "set") self.assertEqual(g(d), "mutablemapping") self.assertEqual(g(l), "mutablesequence") self.assertEqual(g(s), "mutableset") self.assertEqual(g(f), "set") self.assertEqual(g(t), "sequence") g.register(dict, lambda obj: "dict") self.assertEqual(g(d), "dict") self.assertEqual(g(l), "mutablesequence") self.assertEqual(g(s), "mutableset") self.assertEqual(g(f), "set") self.assertEqual(g(t), "sequence") g.register(list, lambda obj: "list") self.assertEqual(g(d), "dict") self.assertEqual(g(l), "list") self.assertEqual(g(s), "mutableset") self.assertEqual(g(f), "set") self.assertEqual(g(t), "sequence") g.register(set, lambda obj: "concrete-set") self.assertEqual(g(d), "dict") self.assertEqual(g(l), "list") self.assertEqual(g(s), "concrete-set") self.assertEqual(g(f), "set") self.assertEqual(g(t), "sequence") g.register(frozenset, lambda obj: "frozen-set") self.assertEqual(g(d), "dict") self.assertEqual(g(l), "list") self.assertEqual(g(s), "concrete-set") self.assertEqual(g(f), "frozen-set") self.assertEqual(g(t), "sequence") g.register(tuple, lambda obj: "tuple") self.assertEqual(g(d), "dict") self.assertEqual(g(l), "list") self.assertEqual(g(s), "concrete-set") self.assertEqual(g(f), "frozen-set") self.assertEqual(g(t), "tuple") def test_c3_abc(self): c = collections mro = functools._c3_mro class A(object): pass class B(A): def __len__(self): return 0 # implies Sized @c.Container.register class C(object): pass class D(object): pass # unrelated class X(D, C, B): def __call__(self): pass # implies Callable expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object] for abcs in permutations([c.Sized, c.Callable, c.Container]): self.assertEqual(mro(X, abcs=abcs), expected) # unrelated ABCs don't appear in the resulting MRO many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable] self.assertEqual(mro(X, abcs=many_abcs), expected) def test_false_meta(self): # see issue23572 class MetaA(type): def __len__(self): return 0 class A(metaclass=MetaA): pass class AA(A): pass @functools.singledispatch def fun(a): return 'base A' @fun.register(A) def _(a): return 'fun A' aa = AA() self.assertEqual(fun(aa), 'fun A') def test_mro_conflicts(self): c = collections @functools.singledispatch def g(arg): return "base" class O(c.Sized): def __len__(self): return 0 o = O() self.assertEqual(g(o), "base") g.register(c.Iterable, lambda arg: "iterable") g.register(c.Container, lambda arg: "container") g.register(c.Sized, lambda arg: "sized") g.register(c.Set, lambda arg: "set") self.assertEqual(g(o), "sized") c.Iterable.register(O) self.assertEqual(g(o), "sized") # because it's explicitly in __mro__ c.Container.register(O) self.assertEqual(g(o), "sized") # see above: Sized is in __mro__ c.Set.register(O) self.assertEqual(g(o), "set") # because c.Set is a subclass of # c.Sized and c.Container class P: pass p = P() self.assertEqual(g(p), "base") c.Iterable.register(P) self.assertEqual(g(p), "iterable") c.Container.register(P) with self.assertRaises(RuntimeError) as re_one: g(p) self.assertIn( str(re_one.exception), (("Ambiguous dispatch: " "or "), ("Ambiguous dispatch: " "or ")), ) class Q(c.Sized): def __len__(self): return 0 q = Q() self.assertEqual(g(q), "sized") c.Iterable.register(Q) self.assertEqual(g(q), "sized") # because it's explicitly in __mro__ c.Set.register(Q) self.assertEqual(g(q), "set") # because c.Set is a subclass of # c.Sized and c.Iterable @functools.singledispatch def h(arg): return "base" @h.register(c.Sized) def _(arg): return "sized" @h.register(c.Container) def _(arg): return "container" # Even though Sized and Container are explicit bases of MutableMapping, # this ABC is implicitly registered on defaultdict which makes all of # MutableMapping's bases implicit as well from defaultdict's # perspective. with self.assertRaises(RuntimeError) as re_two: h(c.defaultdict(lambda: 0)) self.assertIn( str(re_two.exception), (("Ambiguous dispatch: " "or "), ("Ambiguous dispatch: " "or ")), ) class R(c.defaultdict): pass c.MutableSequence.register(R) @functools.singledispatch def i(arg): return "base" @i.register(c.MutableMapping) def _(arg): return "mapping" @i.register(c.MutableSequence) def _(arg): return "sequence" r = R() self.assertEqual(i(r), "sequence") class S: pass class T(S, c.Sized): def __len__(self): return 0 t = T() self.assertEqual(h(t), "sized") c.Container.register(T) self.assertEqual(h(t), "sized") # because it's explicitly in the MRO class U: def __len__(self): return 0 u = U() self.assertEqual(h(u), "sized") # implicit Sized subclass inferred # from the existence of __len__() c.Container.register(U) # There is no preference for registered versus inferred ABCs. with self.assertRaises(RuntimeError) as re_three: h(u) self.assertIn( str(re_three.exception), (("Ambiguous dispatch: " "or "), ("Ambiguous dispatch: " "or ")), ) class V(c.Sized, S): def __len__(self): return 0 @functools.singledispatch def j(arg): return "base" @j.register(S) def _(arg): return "s" @j.register(c.Container) def _(arg): return "container" v = V() self.assertEqual(j(v), "s") c.Container.register(V) self.assertEqual(j(v), "container") # because it ends up right after # Sized in the MRO def test_cache_invalidation(self): from collections import UserDict class TracingDict(UserDict): def __init__(self, *args, **kwargs): super(TracingDict, self).__init__(*args, **kwargs) self.set_ops = [] self.get_ops = [] def __getitem__(self, key): result = self.data[key] self.get_ops.append(key) return result def __setitem__(self, key, value): self.set_ops.append(key) self.data[key] = value def clear(self): self.data.clear() _orig_wkd = functools.WeakKeyDictionary td = TracingDict() functools.WeakKeyDictionary = lambda: td c = collections @functools.singledispatch def g(arg): return "base" d = {} l = [] self.assertEqual(len(td), 0) self.assertEqual(g(d), "base") self.assertEqual(len(td), 1) self.assertEqual(td.get_ops, []) self.assertEqual(td.set_ops, [dict]) self.assertEqual(td.data[dict], g.registry[object]) self.assertEqual(g(l), "base") self.assertEqual(len(td), 2) self.assertEqual(td.get_ops, []) self.assertEqual(td.set_ops, [dict, list]) self.assertEqual(td.data[dict], g.registry[object]) self.assertEqual(td.data[list], g.registry[object]) self.assertEqual(td.data[dict], td.data[list]) self.assertEqual(g(l), "base") self.assertEqual(g(d), "base") self.assertEqual(td.get_ops, [list, dict]) self.assertEqual(td.set_ops, [dict, list]) g.register(list, lambda arg: "list") self.assertEqual(td.get_ops, [list, dict]) self.assertEqual(len(td), 0) self.assertEqual(g(d), "base") self.assertEqual(len(td), 1) self.assertEqual(td.get_ops, [list, dict]) self.assertEqual(td.set_ops, [dict, list, dict]) self.assertEqual(td.data[dict], functools._find_impl(dict, g.registry)) self.assertEqual(g(l), "list") self.assertEqual(len(td), 2) self.assertEqual(td.get_ops, [list, dict]) self.assertEqual(td.set_ops, [dict, list, dict, list]) self.assertEqual(td.data[list], functools._find_impl(list, g.registry)) class X: pass c.MutableMapping.register(X) # Will not invalidate the cache, # not using ABCs yet. self.assertEqual(g(d), "base") self.assertEqual(g(l), "list") self.assertEqual(td.get_ops, [list, dict, dict, list]) self.assertEqual(td.set_ops, [dict, list, dict, list]) g.register(c.Sized, lambda arg: "sized") self.assertEqual(len(td), 0) self.assertEqual(g(d), "sized") self.assertEqual(len(td), 1) self.assertEqual(td.get_ops, [list, dict, dict, list]) self.assertEqual(td.set_ops, [dict, list, dict, list, dict]) self.assertEqual(g(l), "list") self.assertEqual(len(td), 2) self.assertEqual(td.get_ops, [list, dict, dict, list]) self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) self.assertEqual(g(l), "list") self.assertEqual(g(d), "sized") self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict]) self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) g.dispatch(list) g.dispatch(dict) self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict, list, dict]) self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) c.MutableSet.register(X) # Will invalidate the cache. self.assertEqual(len(td), 2) # Stale cache. self.assertEqual(g(l), "list") self.assertEqual(len(td), 1) g.register(c.MutableMapping, lambda arg: "mutablemapping") self.assertEqual(len(td), 0) self.assertEqual(g(d), "mutablemapping") self.assertEqual(len(td), 1) self.assertEqual(g(l), "list") self.assertEqual(len(td), 2) g.register(dict, lambda arg: "dict") self.assertEqual(g(d), "dict") self.assertEqual(g(l), "list") g._clear_cache() self.assertEqual(len(td), 0) functools.WeakKeyDictionary = _orig_wkd if __name__ == '__main__': unittest.main()