Issue #12428: Add a pure Python implementation of functools.partial().

Patch by Brian Thorne.
This commit is contained in:
Antoine Pitrou 2012-11-13 21:35:40 +01:00
parent 65a35dcadd
commit b5b3714168
4 changed files with 167 additions and 73 deletions

View File

@ -11,7 +11,10 @@
__all__ = ['update_wrapper', 'wraps', 'WRAPPER_ASSIGNMENTS', 'WRAPPER_UPDATES',
'total_ordering', 'cmp_to_key', 'lru_cache', 'reduce', 'partial']
from _functools import partial, reduce
try:
from _functools import reduce
except ImportError:
pass
from collections import namedtuple
try:
from _thread import allocate_lock as Lock
@ -136,6 +139,29 @@ except ImportError:
pass
################################################################################
### partial() argument application
################################################################################
def partial(func, *args, **keywords):
"""new function with partial application of the given arguments
and keywords.
"""
def newfunc(*fargs, **fkeywords):
newkeywords = keywords.copy()
newkeywords.update(fkeywords)
return func(*(args + fargs), **newkeywords)
newfunc.func = func
newfunc.args = args
newfunc.keywords = keywords
return newfunc
try:
from _functools import partial
except ImportError:
pass
################################################################################
### LRU Cache function decorator
################################################################################

View File

@ -1,4 +1,3 @@
import functools
import collections
import sys
import unittest
@ -7,17 +6,31 @@ from weakref import proxy
import pickle
from random import choice
@staticmethod
def PythonPartial(func, *args, **keywords):
'Pure Python approximation of partial()'
def newfunc(*fargs, **fkeywords):
newkeywords = keywords.copy()
newkeywords.update(fkeywords)
return func(*(args + fargs), **newkeywords)
newfunc.func = func
newfunc.args = args
newfunc.keywords = keywords
return newfunc
import functools
original_functools = functools
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])
class BaseTest(unittest.TestCase):
"""Base class required for testing C and Py implementations."""
def setUp(self):
# The module must be explicitly set so that the proper
# interaction between the c module and the python module
# can be controlled.
self.partial = self.module.partial
super(BaseTest, self).setUp()
class BaseTestC(BaseTest):
module = c_functools
class BaseTestPy(BaseTest):
module = py_functools
PythonPartial = py_functools.partial
def capture(*args, **kw):
"""capture all positional and keyword arguments"""
@ -27,31 +40,32 @@ def signature(part):
""" return the signature of a partial object """
return (part.func, part.args, part.keywords, part.__dict__)
class TestPartial(unittest.TestCase):
class TestPartial(object):
thetype = functools.partial
partial = functools.partial
def test_basic_examples(self):
p = self.thetype(capture, 1, 2, a=10, b=20)
p = self.partial(capture, 1, 2, a=10, b=20)
self.assertTrue(callable(p))
self.assertEqual(p(3, 4, b=30, c=40),
((1, 2, 3, 4), dict(a=10, b=30, c=40)))
p = self.thetype(map, lambda x: x*10)
p = self.partial(map, lambda x: x*10)
self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])
def test_attributes(self):
p = self.thetype(capture, 1, 2, a=10, b=20)
p = self.partial(capture, 1, 2, a=10, b=20)
# attributes should be readable
self.assertEqual(p.func, capture)
self.assertEqual(p.args, (1, 2))
self.assertEqual(p.keywords, dict(a=10, b=20))
# attributes should not be writable
if not isinstance(self.thetype, type):
if not isinstance(self.partial, type):
return
self.assertRaises(AttributeError, setattr, p, 'func', map)
self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))
p = self.thetype(hex)
p = self.partial(hex)
try:
del p.__dict__
except TypeError:
@ -60,9 +74,9 @@ class TestPartial(unittest.TestCase):
self.fail('partial object allowed __dict__ to be deleted')
def test_argument_checking(self):
self.assertRaises(TypeError, self.thetype) # need at least a func arg
self.assertRaises(TypeError, self.partial) # need at least a func arg
try:
self.thetype(2)()
self.partial(2)()
except TypeError:
pass
else:
@ -73,7 +87,7 @@ class TestPartial(unittest.TestCase):
def func(a=10, b=20):
return a
d = {'a':3}
p = self.thetype(func, a=5)
p = self.partial(func, a=5)
self.assertEqual(p(**d), 3)
self.assertEqual(d, {'a':3})
p(b=7)
@ -82,20 +96,20 @@ class TestPartial(unittest.TestCase):
def test_arg_combinations(self):
# exercise special code paths for zero args in either partial
# object or the caller
p = self.thetype(capture)
p = self.partial(capture)
self.assertEqual(p(), ((), {}))
self.assertEqual(p(1,2), ((1,2), {}))
p = self.thetype(capture, 1, 2)
p = self.partial(capture, 1, 2)
self.assertEqual(p(), ((1,2), {}))
self.assertEqual(p(3,4), ((1,2,3,4), {}))
def test_kw_combinations(self):
# exercise special code paths for no keyword args in
# either the partial object or the caller
p = self.thetype(capture)
p = self.partial(capture)
self.assertEqual(p(), ((), {}))
self.assertEqual(p(a=1), ((), {'a':1}))
p = self.thetype(capture, a=1)
p = self.partial(capture, a=1)
self.assertEqual(p(), ((), {'a':1}))
self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
# keyword args in the call override those in the partial object
@ -104,7 +118,7 @@ class TestPartial(unittest.TestCase):
def test_positional(self):
# make sure positional arguments are captured correctly
for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
p = self.thetype(capture, *args)
p = self.partial(capture, *args)
expected = args + ('x',)
got, empty = p('x')
self.assertTrue(expected == got and empty == {})
@ -112,14 +126,14 @@ class TestPartial(unittest.TestCase):
def test_keyword(self):
# make sure keyword arguments are captured correctly
for a in ['a', 0, None, 3.5]:
p = self.thetype(capture, a=a)
p = self.partial(capture, a=a)
expected = {'a':a,'x':None}
empty, got = p(x=None)
self.assertTrue(expected == got and empty == ())
def test_no_side_effects(self):
# make sure there are no side effects that affect subsequent calls
p = self.thetype(capture, 0, a=1)
p = self.partial(capture, 0, a=1)
args1, kw1 = p(1, b=2)
self.assertTrue(args1 == (0,1) and kw1 == {'a':1,'b':2})
args2, kw2 = p()
@ -128,13 +142,13 @@ class TestPartial(unittest.TestCase):
def test_error_propagation(self):
def f(x, y):
x / y
self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)
self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)
def test_weakref(self):
f = self.thetype(int, base=16)
f = self.partial(int, base=16)
p = proxy(f)
self.assertEqual(f.func, p.func)
f = None
@ -142,9 +156,9 @@ class TestPartial(unittest.TestCase):
def test_with_bound_and_unbound_methods(self):
data = list(map(str, range(10)))
join = self.thetype(str.join, '')
join = self.partial(str.join, '')
self.assertEqual(join(data), '0123456789')
join = self.thetype(''.join)
join = self.partial(''.join)
self.assertEqual(join(data), '0123456789')
def test_repr(self):
@ -152,49 +166,57 @@ class TestPartial(unittest.TestCase):
args_repr = ', '.join(repr(a) for a in args)
kwargs = {'a': object(), 'b': object()}
kwargs_repr = ', '.join("%s=%r" % (k, v) for k, v in kwargs.items())
if self.thetype is functools.partial:
if self.partial is functools.partial:
name = 'functools.partial'
else:
name = self.thetype.__name__
name = self.partial.__name__
f = self.thetype(capture)
f = self.partial(capture)
self.assertEqual('{}({!r})'.format(name, capture),
repr(f))
f = self.thetype(capture, *args)
f = self.partial(capture, *args)
self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
repr(f))
f = self.thetype(capture, **kwargs)
f = self.partial(capture, **kwargs)
self.assertEqual('{}({!r}, {})'.format(name, capture, kwargs_repr),
repr(f))
f = self.thetype(capture, *args, **kwargs)
f = self.partial(capture, *args, **kwargs)
self.assertEqual('{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr),
repr(f))
def test_pickle(self):
f = self.thetype(signature, 'asdf', bar=True)
f = self.partial(signature, 'asdf', bar=True)
f.add_something_to__dict__ = True
f_copy = pickle.loads(pickle.dumps(f))
self.assertEqual(signature(f), signature(f_copy))
class PartialSubclass(functools.partial):
class TestPartialC(BaseTestC, TestPartial):
pass
class TestPartialSubclass(TestPartial):
class TestPartialPy(BaseTestPy, TestPartial):
thetype = PartialSubclass
def test_pickle(self):
raise unittest.SkipTest("Python implementation of partial isn't picklable")
class TestPythonPartial(TestPartial):
def test_repr(self):
raise unittest.SkipTest("Python implementation of partial uses own repr")
thetype = PythonPartial
class TestPartialCSubclass(BaseTestC, TestPartial):
# the python version hasn't a nice repr
def test_repr(self): pass
class PartialSubclass(c_functools.partial):
pass
# the python version isn't picklable
def test_pickle(self): pass
partial = staticmethod(PartialSubclass)
class TestPartialPySubclass(TestPartialPy):
class PartialSubclass(c_functools.partial):
pass
partial = staticmethod(PartialSubclass)
class TestUpdateWrapper(unittest.TestCase):
@ -320,7 +342,7 @@ class TestWraps(TestUpdateWrapper):
self.assertEqual(wrapper.__qualname__, f.__qualname__)
self.assertEqual(wrapper.attr, 'This is also a test')
@unittest.skipIf(not sys.flags.optimize <= 1,
@unittest.skipIf(sys.flags.optimize >= 2,
"Docstrings are omitted with -O2 and above")
def test_default_update_doc(self):
wrapper, _ = self._default_update()
@ -441,24 +463,28 @@ class TestReduce(unittest.TestCase):
d = {"one": 1, "two": 2, "three": 3}
self.assertEqual(self.func(add, d), "".join(d.keys()))
class TestCmpToKey(unittest.TestCase):
class TestCmpToKey(object):
def test_cmp_to_key(self):
def cmp1(x, y):
return (x > y) - (x < y)
key = functools.cmp_to_key(cmp1)
key = self.cmp_to_key(cmp1)
self.assertEqual(key(3), key(3))
self.assertGreater(key(3), key(1))
self.assertGreaterEqual(key(3), key(3))
def cmp2(x, y):
return int(x) - int(y)
key = functools.cmp_to_key(cmp2)
key = self.cmp_to_key(cmp2)
self.assertEqual(key(4.0), key('4'))
self.assertLess(key(2), key('35'))
self.assertLessEqual(key(2), key('35'))
self.assertNotEqual(key(2), key('35'))
def test_cmp_to_key_arguments(self):
def cmp1(x, y):
return (x > y) - (x < y)
key = functools.cmp_to_key(mycmp=cmp1)
key = self.cmp_to_key(mycmp=cmp1)
self.assertEqual(key(obj=3), key(obj=3))
self.assertGreater(key(obj=3), key(obj=1))
with self.assertRaises((TypeError, AttributeError)):
@ -466,10 +492,10 @@ class TestCmpToKey(unittest.TestCase):
with self.assertRaises((TypeError, AttributeError)):
1 < key(3) # lhs is not a K object
with self.assertRaises(TypeError):
key = functools.cmp_to_key() # too few args
key = self.cmp_to_key() # too few args
with self.assertRaises(TypeError):
key = functools.cmp_to_key(cmp1, None) # too many args
key = functools.cmp_to_key(cmp1)
key = self.module.cmp_to_key(cmp1, None) # too many args
key = self.cmp_to_key(cmp1)
with self.assertRaises(TypeError):
key() # too few args
with self.assertRaises(TypeError):
@ -478,7 +504,7 @@ class TestCmpToKey(unittest.TestCase):
def test_bad_cmp(self):
def cmp1(x, y):
raise ZeroDivisionError
key = functools.cmp_to_key(cmp1)
key = self.cmp_to_key(cmp1)
with self.assertRaises(ZeroDivisionError):
key(3) > key(1)
@ -493,13 +519,13 @@ class TestCmpToKey(unittest.TestCase):
def test_obj_field(self):
def cmp1(x, y):
return (x > y) - (x < y)
key = functools.cmp_to_key(mycmp=cmp1)
key = self.cmp_to_key(mycmp=cmp1)
self.assertEqual(key(50).obj, 50)
def test_sort_int(self):
def mycmp(x, y):
return y - x
self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
[4, 3, 2, 1, 0])
def test_sort_int_str(self):
@ -507,18 +533,24 @@ class TestCmpToKey(unittest.TestCase):
x, y = int(x), int(y)
return (x > y) - (x < y)
values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
values = sorted(values, key=functools.cmp_to_key(mycmp))
values = sorted(values, key=self.cmp_to_key(mycmp))
self.assertEqual([int(value) for value in values],
[0, 1, 1, 2, 3, 4, 5, 7, 10])
def test_hash(self):
def mycmp(x, y):
return y - x
key = functools.cmp_to_key(mycmp)
key = self.cmp_to_key(mycmp)
k = key(10)
self.assertRaises(TypeError, hash, k)
self.assertNotIsInstance(k, collections.Hashable)
class TestCmpToKeyC(BaseTestC, TestCmpToKey):
cmp_to_key = c_functools.cmp_to_key
class TestCmpToKeyPy(BaseTestPy, TestCmpToKey):
cmp_to_key = staticmethod(py_functools.cmp_to_key)
class TestTotalOrdering(unittest.TestCase):
def test_total_ordering_lt(self):
@ -728,7 +760,7 @@ class TestLRU(unittest.TestCase):
# Verify that user_function exceptions get passed through without
# creating a hard-to-read chained exception.
# http://bugs.python.org/issue13177
for maxsize in (None, 100):
for maxsize in (None, 128):
@functools.lru_cache(maxsize)
def func(i):
return 'abc'[i]
@ -741,7 +773,7 @@ class TestLRU(unittest.TestCase):
func(15)
def test_lru_with_types(self):
for maxsize in (None, 100):
for maxsize in (None, 128):
@functools.lru_cache(maxsize=maxsize, typed=True)
def square(x):
return x * x
@ -756,14 +788,46 @@ class TestLRU(unittest.TestCase):
self.assertEqual(square.cache_info().hits, 4)
self.assertEqual(square.cache_info().misses, 4)
def test_lru_with_keyword_args(self):
@functools.lru_cache()
def fib(n):
if n < 2:
return n
return fib(n=n-1) + fib(n=n-2)
self.assertEqual(
[fib(n=number) for number in range(16)],
[0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
)
self.assertEqual(fib.cache_info(),
functools._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
fib.cache_clear()
self.assertEqual(fib.cache_info(),
functools._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
def test_lru_with_keyword_args_maxsize_none(self):
@functools.lru_cache(maxsize=None)
def fib(n):
if n < 2:
return n
return fib(n=n-1) + fib(n=n-2)
self.assertEqual([fib(n=number) for number in range(16)],
[0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
self.assertEqual(fib.cache_info(),
functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
fib.cache_clear()
self.assertEqual(fib.cache_info(),
functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
def test_main(verbose=None):
test_classes = (
TestPartial,
TestPartialSubclass,
TestPythonPartial,
TestPartialC,
TestPartialPy,
TestPartialCSubclass,
TestPartialPySubclass,
TestUpdateWrapper,
TestTotalOrdering,
TestCmpToKey,
TestCmpToKeyC,
TestCmpToKeyPy,
TestWraps,
TestReduce,
TestLRU,

View File

@ -1166,6 +1166,7 @@ Tobias Thelen
Nicolas M. Thiéry
James Thomas
Robin Thomas
Brian Thorne
Stephen Thorne
Jeremy Thurgood
Eric Tiedemann

View File

@ -124,6 +124,9 @@ Core and Builtins
Library
-------
- Issue #12428: Add a pure Python implementation of functools.partial().
Patch by Brian Thorne.
- Issue #16140: The subprocess module no longer double closes its child
subprocess.PIPE parent file descriptors on child error prior to exec().