Issue #27137: align Python & C implementations of functools.partial
The pure Python fallback implementation of functools.partial now matches the behaviour of its accelerated C counterpart for subclassing, pickling and text representation purposes. Patch by Emanuel Barry and Serhiy Storchaka.
This commit is contained in:
parent
eddc4b7272
commit
457fc9a69e
|
@ -21,6 +21,7 @@ from abc import get_cache_token
|
|||
from collections import namedtuple
|
||||
from types import MappingProxyType
|
||||
from weakref import WeakKeyDictionary
|
||||
from reprlib import recursive_repr
|
||||
try:
|
||||
from _thread import RLock
|
||||
except ImportError:
|
||||
|
@ -237,11 +238,24 @@ except ImportError:
|
|||
################################################################################
|
||||
|
||||
# Purely functional, no descriptor behaviour
|
||||
def partial(func, *args, **keywords):
|
||||
class partial:
|
||||
"""New function with partial application of the given arguments
|
||||
and keywords.
|
||||
"""
|
||||
if hasattr(func, 'func'):
|
||||
|
||||
__slots__ = "func", "args", "keywords", "__dict__", "__weakref__"
|
||||
|
||||
def __new__(*args, **keywords):
|
||||
if not args:
|
||||
raise TypeError("descriptor '__new__' of partial needs an argument")
|
||||
if len(args) < 2:
|
||||
raise TypeError("type 'partial' takes at least one argument")
|
||||
cls, func, *args = args
|
||||
if not callable(func):
|
||||
raise TypeError("the first argument must be callable")
|
||||
args = tuple(args)
|
||||
|
||||
if hasattr(func, "func"):
|
||||
args = func.args + args
|
||||
tmpkw = func.keywords.copy()
|
||||
tmpkw.update(keywords)
|
||||
|
@ -249,14 +263,58 @@ def partial(func, *args, **keywords):
|
|||
del tmpkw
|
||||
func = func.func
|
||||
|
||||
def newfunc(*fargs, **fkeywords):
|
||||
newkeywords = keywords.copy()
|
||||
newkeywords.update(fkeywords)
|
||||
return func(*(args + fargs), **newkeywords)
|
||||
newfunc.func = func
|
||||
newfunc.args = args
|
||||
newfunc.keywords = keywords
|
||||
return newfunc
|
||||
self = super(partial, cls).__new__(cls)
|
||||
|
||||
self.func = func
|
||||
self.args = args
|
||||
self.keywords = keywords
|
||||
return self
|
||||
|
||||
def __call__(*args, **keywords):
|
||||
if not args:
|
||||
raise TypeError("descriptor '__call__' of partial needs an argument")
|
||||
self, *args = args
|
||||
newkeywords = self.keywords.copy()
|
||||
newkeywords.update(keywords)
|
||||
return self.func(*self.args, *args, **newkeywords)
|
||||
|
||||
@recursive_repr()
|
||||
def __repr__(self):
|
||||
qualname = type(self).__qualname__
|
||||
args = [repr(self.func)]
|
||||
args.extend(repr(x) for x in self.args)
|
||||
args.extend(f"{k}={v!r}" for (k, v) in self.keywords.items())
|
||||
if type(self).__module__ == "functools":
|
||||
return f"functools.{qualname}({', '.join(args)})"
|
||||
return f"{qualname}({', '.join(args)})"
|
||||
|
||||
def __reduce__(self):
|
||||
return type(self), (self.func,), (self.func, self.args,
|
||||
self.keywords or None, self.__dict__ or None)
|
||||
|
||||
def __setstate__(self, state):
|
||||
if not isinstance(state, tuple):
|
||||
raise TypeError("argument to __setstate__ must be a tuple")
|
||||
if len(state) != 4:
|
||||
raise TypeError(f"expected 4 items in state, got {len(state)}")
|
||||
func, args, kwds, namespace = state
|
||||
if (not callable(func) or not isinstance(args, tuple) or
|
||||
(kwds is not None and not isinstance(kwds, dict)) or
|
||||
(namespace is not None and not isinstance(namespace, dict))):
|
||||
raise TypeError("invalid partial state")
|
||||
|
||||
args = tuple(args) # just in case it's a subclass
|
||||
if kwds is None:
|
||||
kwds = {}
|
||||
elif type(kwds) is not dict: # XXX does it need to be *exactly* dict?
|
||||
kwds = dict(kwds)
|
||||
if namespace is None:
|
||||
namespace = {}
|
||||
|
||||
self.__dict__ = namespace
|
||||
self.func = func
|
||||
self.args = args
|
||||
self.keywords = kwds
|
||||
|
||||
try:
|
||||
from _functools import partial
|
||||
|
|
|
@ -8,6 +8,7 @@ import sys
|
|||
from test import support
|
||||
import unittest
|
||||
from weakref import proxy
|
||||
import contextlib
|
||||
try:
|
||||
import threading
|
||||
except ImportError:
|
||||
|
@ -20,6 +21,14 @@ c_functools = support.import_fresh_module('functools', fresh=['_functools'])
|
|||
|
||||
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
|
||||
|
||||
@contextlib.contextmanager
|
||||
def replaced_module(name, replacement):
|
||||
original_module = sys.modules[name]
|
||||
sys.modules[name] = replacement
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
sys.modules[name] = original_module
|
||||
|
||||
def capture(*args, **kw):
|
||||
"""capture all positional and keyword arguments"""
|
||||
|
@ -167,58 +176,35 @@ class TestPartial:
|
|||
p2.new_attr = 'spam'
|
||||
self.assertEqual(p2.new_attr, 'spam')
|
||||
|
||||
|
||||
@unittest.skipUnless(c_functools, 'requires the C _functools module')
|
||||
class TestPartialC(TestPartial, unittest.TestCase):
|
||||
if c_functools:
|
||||
partial = c_functools.partial
|
||||
|
||||
def test_attributes_unwritable(self):
|
||||
# attributes should not be writable
|
||||
p = self.partial(capture, 1, 2, a=10, b=20)
|
||||
self.assertRaises(AttributeError, setattr, p, 'func', map)
|
||||
self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
|
||||
self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))
|
||||
|
||||
p = self.partial(hex)
|
||||
try:
|
||||
del p.__dict__
|
||||
except TypeError:
|
||||
pass
|
||||
else:
|
||||
self.fail('partial object allowed __dict__ to be deleted')
|
||||
|
||||
def test_repr(self):
|
||||
args = (object(), object())
|
||||
args_repr = ', '.join(repr(a) for a in args)
|
||||
kwargs = {'a': object(), 'b': object()}
|
||||
kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
|
||||
'b={b!r}, a={a!r}'.format_map(kwargs)]
|
||||
if self.partial is c_functools.partial:
|
||||
if self.partial in (c_functools.partial, py_functools.partial):
|
||||
name = 'functools.partial'
|
||||
else:
|
||||
name = self.partial.__name__
|
||||
|
||||
f = self.partial(capture)
|
||||
self.assertEqual('{}({!r})'.format(name, capture),
|
||||
repr(f))
|
||||
self.assertEqual(f'{name}({capture!r})', repr(f))
|
||||
|
||||
f = self.partial(capture, *args)
|
||||
self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
|
||||
repr(f))
|
||||
self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))
|
||||
|
||||
f = self.partial(capture, **kwargs)
|
||||
self.assertIn(repr(f),
|
||||
['{}({!r}, {})'.format(name, capture, kwargs_repr)
|
||||
[f'{name}({capture!r}, {kwargs_repr})'
|
||||
for kwargs_repr in kwargs_reprs])
|
||||
|
||||
f = self.partial(capture, *args, **kwargs)
|
||||
self.assertIn(repr(f),
|
||||
['{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr)
|
||||
[f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
|
||||
for kwargs_repr in kwargs_reprs])
|
||||
|
||||
def test_recursive_repr(self):
|
||||
if self.partial is c_functools.partial:
|
||||
if self.partial in (c_functools.partial, py_functools.partial):
|
||||
name = 'functools.partial'
|
||||
else:
|
||||
name = self.partial.__name__
|
||||
|
@ -226,25 +212,26 @@ class TestPartialC(TestPartial, unittest.TestCase):
|
|||
f = self.partial(capture)
|
||||
f.__setstate__((f, (), {}, {}))
|
||||
try:
|
||||
self.assertEqual(repr(f), '%s(%s(...))' % (name, name))
|
||||
self.assertEqual(repr(f), '%s(...)' % (name,))
|
||||
finally:
|
||||
f.__setstate__((capture, (), {}, {}))
|
||||
|
||||
f = self.partial(capture)
|
||||
f.__setstate__((capture, (f,), {}, {}))
|
||||
try:
|
||||
self.assertEqual(repr(f), '%s(%r, %s(...))' % (name, capture, name))
|
||||
self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
|
||||
finally:
|
||||
f.__setstate__((capture, (), {}, {}))
|
||||
|
||||
f = self.partial(capture)
|
||||
f.__setstate__((capture, (), {'a': f}, {}))
|
||||
try:
|
||||
self.assertEqual(repr(f), '%s(%r, a=%s(...))' % (name, capture, name))
|
||||
self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
|
||||
finally:
|
||||
f.__setstate__((capture, (), {}, {}))
|
||||
|
||||
def test_pickle(self):
|
||||
with self.AllowPickle():
|
||||
f = self.partial(signature, ['asdf'], bar=[True])
|
||||
f.attr = []
|
||||
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
|
||||
|
@ -274,11 +261,13 @@ class TestPartialC(TestPartial, unittest.TestCase):
|
|||
def test_setstate(self):
|
||||
f = self.partial(signature)
|
||||
f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))
|
||||
|
||||
self.assertEqual(signature(f),
|
||||
(capture, (1,), dict(a=10), dict(attr=[])))
|
||||
self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))
|
||||
|
||||
f.__setstate__((capture, (1,), dict(a=10), None))
|
||||
|
||||
self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
|
||||
self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))
|
||||
|
||||
|
@ -325,6 +314,7 @@ class TestPartialC(TestPartial, unittest.TestCase):
|
|||
self.assertIs(type(r[0]), tuple)
|
||||
|
||||
def test_recursive_pickle(self):
|
||||
with self.AllowPickle():
|
||||
f = self.partial(capture)
|
||||
f.__setstate__((f, (), {}, {}))
|
||||
try:
|
||||
|
@ -375,24 +365,60 @@ class TestPartialC(TestPartial, unittest.TestCase):
|
|||
f = self.partial(object)
|
||||
self.assertRaises(TypeError, f.__setstate__, BadSequence())
|
||||
|
||||
@unittest.skipUnless(c_functools, 'requires the C _functools module')
|
||||
class TestPartialC(TestPartial, unittest.TestCase):
|
||||
if c_functools:
|
||||
partial = c_functools.partial
|
||||
|
||||
class AllowPickle:
|
||||
def __enter__(self):
|
||||
return self
|
||||
def __exit__(self, type, value, tb):
|
||||
return False
|
||||
|
||||
def test_attributes_unwritable(self):
|
||||
# attributes should not be writable
|
||||
p = self.partial(capture, 1, 2, a=10, b=20)
|
||||
self.assertRaises(AttributeError, setattr, p, 'func', map)
|
||||
self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
|
||||
self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))
|
||||
|
||||
p = self.partial(hex)
|
||||
try:
|
||||
del p.__dict__
|
||||
except TypeError:
|
||||
pass
|
||||
else:
|
||||
self.fail('partial object allowed __dict__ to be deleted')
|
||||
|
||||
class TestPartialPy(TestPartial, unittest.TestCase):
|
||||
partial = staticmethod(py_functools.partial)
|
||||
partial = py_functools.partial
|
||||
|
||||
class AllowPickle:
|
||||
def __init__(self):
|
||||
self._cm = replaced_module("functools", py_functools)
|
||||
def __enter__(self):
|
||||
return self._cm.__enter__()
|
||||
def __exit__(self, type, value, tb):
|
||||
return self._cm.__exit__(type, value, tb)
|
||||
|
||||
if c_functools:
|
||||
class PartialSubclass(c_functools.partial):
|
||||
class CPartialSubclass(c_functools.partial):
|
||||
pass
|
||||
|
||||
class PyPartialSubclass(py_functools.partial):
|
||||
pass
|
||||
|
||||
@unittest.skipUnless(c_functools, 'requires the C _functools module')
|
||||
class TestPartialCSubclass(TestPartialC):
|
||||
if c_functools:
|
||||
partial = PartialSubclass
|
||||
partial = CPartialSubclass
|
||||
|
||||
# partial subclasses are not optimized for nested calls
|
||||
test_nested_optimization = None
|
||||
|
||||
class TestPartialPySubclass(TestPartialPy):
|
||||
partial = PyPartialSubclass
|
||||
|
||||
class TestPartialMethod(unittest.TestCase):
|
||||
|
||||
|
@ -683,9 +709,10 @@ class TestWraps(TestUpdateWrapper):
|
|||
self.assertEqual(wrapper.attr, 'This is a different test')
|
||||
self.assertEqual(wrapper.dict_attr, f.dict_attr)
|
||||
|
||||
|
||||
@unittest.skipUnless(c_functools, 'requires the C _functools module')
|
||||
class TestReduce(unittest.TestCase):
|
||||
func = functools.reduce
|
||||
if c_functools:
|
||||
func = c_functools.reduce
|
||||
|
||||
def test_reduce(self):
|
||||
class Squares:
|
||||
|
|
|
@ -135,6 +135,11 @@ Core and Builtins
|
|||
Library
|
||||
-------
|
||||
|
||||
- Issue #27137: the pure Python fallback implementation of ``functools.partial``
|
||||
now matches the behaviour of its accelerated C counterpart for subclassing,
|
||||
pickling and text representation purposes. Patch by Emanuel Barry and
|
||||
Serhiy Storchaka.
|
||||
|
||||
- Issue #28019: itertools.count() no longer rounds non-integer step in range
|
||||
between 1.0 and 2.0 to 1.
|
||||
|
||||
|
|
|
@ -229,7 +229,7 @@ partial_repr(partialobject *pto)
|
|||
if (status != 0) {
|
||||
if (status < 0)
|
||||
return NULL;
|
||||
return PyUnicode_FromFormat("%s(...)", Py_TYPE(pto)->tp_name);
|
||||
return PyUnicode_FromString("...");
|
||||
}
|
||||
|
||||
arglist = PyUnicode_FromString("");
|
||||
|
|
Loading…
Reference in New Issue