inspect.signautre: Fix functools.partial support. Issue #21117

This commit is contained in:
Yury Selivanov 2014-04-08 11:30:45 -04:00
parent 1d1d95bf83
commit 3f73ca23cf
3 changed files with 146 additions and 97 deletions

View File

@ -1518,7 +1518,8 @@ def _signature_get_partial(wrapped_sig, partial, extra_args=()):
on it. on it.
""" """
new_params = OrderedDict(wrapped_sig.parameters.items()) old_params = wrapped_sig.parameters
new_params = OrderedDict(old_params.items())
partial_args = partial.args or () partial_args = partial.args or ()
partial_keywords = partial.keywords or {} partial_keywords = partial.keywords or {}
@ -1532,32 +1533,57 @@ def _signature_get_partial(wrapped_sig, partial, extra_args=()):
msg = 'partial object {!r} has incorrect arguments'.format(partial) msg = 'partial object {!r} has incorrect arguments'.format(partial)
raise ValueError(msg) from ex raise ValueError(msg) from ex
for arg_name, arg_value in ba.arguments.items():
param = new_params[arg_name]
if arg_name in partial_keywords:
# We set a new default value, because the following code
# is correct:
#
# >>> def foo(a): print(a)
# >>> print(partial(partial(foo, a=10), a=20)())
# 20
# >>> print(partial(partial(foo, a=10), a=20)(a=30))
# 30
#
# So, with 'partial' objects, passing a keyword argument is
# like setting a new default value for the corresponding
# parameter
#
# We also mark this parameter with '_partial_kwarg'
# flag. Later, in '_bind', the 'default' value of this
# parameter will be added to 'kwargs', to simulate
# the 'functools.partial' real call.
new_params[arg_name] = param.replace(default=arg_value,
_partial_kwarg=True)
elif (param.kind not in (_VAR_KEYWORD, _VAR_POSITIONAL) and transform_to_kwonly = False
not param._partial_kwarg): for param_name, param in old_params.items():
new_params.pop(arg_name) try:
arg_value = ba.arguments[param_name]
except KeyError:
pass
else:
if param.kind is _POSITIONAL_ONLY:
# If positional-only parameter is bound by partial,
# it effectively disappears from the signature
new_params.pop(param_name)
continue
if param.kind is _POSITIONAL_OR_KEYWORD:
if param_name in partial_keywords:
# This means that this parameter, and all parameters
# after it should be keyword-only (and var-positional
# should be removed). Here's why. Consider the following
# function:
# foo(a, b, *args, c):
# pass
#
# "partial(foo, a='spam')" will have the following
# signature: "(*, a='spam', b, c)". Because attempting
# to call that partial with "(10, 20)" arguments will
# raise a TypeError, saying that "a" argument received
# multiple values.
transform_to_kwonly = True
# Set the new default value
new_params[param_name] = param.replace(default=arg_value)
else:
# was passed as a positional argument
new_params.pop(param.name)
continue
if param.kind is _KEYWORD_ONLY:
# Set the new default value
new_params[param_name] = param.replace(default=arg_value)
if transform_to_kwonly:
assert param.kind is not _POSITIONAL_ONLY
if param.kind is _POSITIONAL_OR_KEYWORD:
new_param = new_params[param_name].replace(kind=_KEYWORD_ONLY)
new_params[param_name] = new_param
new_params.move_to_end(param_name)
elif param.kind in (_KEYWORD_ONLY, _VAR_KEYWORD):
new_params.move_to_end(param_name)
elif param.kind is _VAR_POSITIONAL:
new_params.pop(param.name)
return wrapped_sig.replace(parameters=new_params.values()) return wrapped_sig.replace(parameters=new_params.values())
@ -2103,7 +2129,7 @@ class Parameter:
`Parameter.KEYWORD_ONLY`, `Parameter.VAR_KEYWORD`. `Parameter.KEYWORD_ONLY`, `Parameter.VAR_KEYWORD`.
""" """
__slots__ = ('_name', '_kind', '_default', '_annotation', '_partial_kwarg') __slots__ = ('_name', '_kind', '_default', '_annotation')
POSITIONAL_ONLY = _POSITIONAL_ONLY POSITIONAL_ONLY = _POSITIONAL_ONLY
POSITIONAL_OR_KEYWORD = _POSITIONAL_OR_KEYWORD POSITIONAL_OR_KEYWORD = _POSITIONAL_OR_KEYWORD
@ -2113,8 +2139,7 @@ class Parameter:
empty = _empty empty = _empty
def __init__(self, name, kind, *, default=_empty, annotation=_empty, def __init__(self, name, kind, *, default=_empty, annotation=_empty):
_partial_kwarg=False):
if kind not in (_POSITIONAL_ONLY, _POSITIONAL_OR_KEYWORD, if kind not in (_POSITIONAL_ONLY, _POSITIONAL_OR_KEYWORD,
_VAR_POSITIONAL, _KEYWORD_ONLY, _VAR_KEYWORD): _VAR_POSITIONAL, _KEYWORD_ONLY, _VAR_KEYWORD):
@ -2139,17 +2164,13 @@ class Parameter:
self._name = name self._name = name
self._partial_kwarg = _partial_kwarg
def __reduce__(self): def __reduce__(self):
return (type(self), return (type(self),
(self._name, self._kind), (self._name, self._kind),
{'_partial_kwarg': self._partial_kwarg, {'_default': self._default,
'_default': self._default,
'_annotation': self._annotation}) '_annotation': self._annotation})
def __setstate__(self, state): def __setstate__(self, state):
self._partial_kwarg = state['_partial_kwarg']
self._default = state['_default'] self._default = state['_default']
self._annotation = state['_annotation'] self._annotation = state['_annotation']
@ -2169,8 +2190,8 @@ class Parameter:
def kind(self): def kind(self):
return self._kind return self._kind
def replace(self, *, name=_void, kind=_void, annotation=_void, def replace(self, *, name=_void, kind=_void,
default=_void, _partial_kwarg=_void): annotation=_void, default=_void):
"""Creates a customized copy of the Parameter.""" """Creates a customized copy of the Parameter."""
if name is _void: if name is _void:
@ -2185,11 +2206,7 @@ class Parameter:
if default is _void: if default is _void:
default = self._default default = self._default
if _partial_kwarg is _void: return type(self)(name, kind, default=default, annotation=annotation)
_partial_kwarg = self._partial_kwarg
return type(self)(name, kind, default=default, annotation=annotation,
_partial_kwarg=_partial_kwarg)
def __str__(self): def __str__(self):
kind = self.kind kind = self.kind
@ -2215,17 +2232,6 @@ class Parameter:
id(self), self) id(self), self)
def __eq__(self, other): def __eq__(self, other):
# NB: We deliberately do not compare '_partial_kwarg' attributes
# here. Imagine we have a following situation:
#
# def foo(a, b=1): pass
# def bar(a, b): pass
# bar2 = functools.partial(bar, b=1)
#
# For the above scenario, signatures for `foo` and `bar2` should
# be equal. '_partial_kwarg' attribute is an internal flag, to
# distinguish between keyword parameters with defaults and
# keyword parameters which got their defaults from functools.partial
return (issubclass(other.__class__, Parameter) and return (issubclass(other.__class__, Parameter) and
self._name == other._name and self._name == other._name and
self._kind == other._kind and self._kind == other._kind and
@ -2265,12 +2271,7 @@ class BoundArguments:
def args(self): def args(self):
args = [] args = []
for param_name, param in self._signature.parameters.items(): for param_name, param in self._signature.parameters.items():
if (param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY) or if param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY):
param._partial_kwarg):
# Keyword arguments mapped by 'functools.partial'
# (Parameter._partial_kwarg is True) are mapped
# in 'BoundArguments.kwargs', along with VAR_KEYWORD &
# KEYWORD_ONLY
break break
try: try:
@ -2295,8 +2296,7 @@ class BoundArguments:
kwargs_started = False kwargs_started = False
for param_name, param in self._signature.parameters.items(): for param_name, param in self._signature.parameters.items():
if not kwargs_started: if not kwargs_started:
if (param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY) or if param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY):
param._partial_kwarg):
kwargs_started = True kwargs_started = True
else: else:
if param_name not in self.arguments: if param_name not in self.arguments:
@ -2378,18 +2378,14 @@ class Signature:
name = param.name name = param.name
if kind < top_kind: if kind < top_kind:
msg = 'wrong parameter order: {} before {}' msg = 'wrong parameter order: {!r} before {!r}'
msg = msg.format(top_kind, kind) msg = msg.format(top_kind, kind)
raise ValueError(msg) raise ValueError(msg)
elif kind > top_kind: elif kind > top_kind:
kind_defaults = False kind_defaults = False
top_kind = kind top_kind = kind
if (kind in (_POSITIONAL_ONLY, _POSITIONAL_OR_KEYWORD) and if kind in (_POSITIONAL_ONLY, _POSITIONAL_OR_KEYWORD):
not param._partial_kwarg):
# If we have a positional-only or positional-or-keyword
# parameter, that does not have its default value set
# by 'functools.partial' or other "partial" signature:
if param.default is _empty: if param.default is _empty:
if kind_defaults: if kind_defaults:
# No default for this parameter, but the # No default for this parameter, but the
@ -2570,15 +2566,6 @@ class Signature:
parameters_ex = () parameters_ex = ()
arg_vals = iter(args) arg_vals = iter(args)
if partial:
# Support for binding arguments to 'functools.partial' objects.
# See 'functools.partial' case in 'signature()' implementation
# for details.
for param_name, param in self.parameters.items():
if (param._partial_kwarg and param_name not in kwargs):
# Simulating 'functools.partial' behavior
kwargs[param_name] = param.default
while True: while True:
# Let's iterate through the positional arguments and corresponding # Let's iterate through the positional arguments and corresponding
# parameters # parameters

View File

@ -1689,13 +1689,11 @@ class TestSignatureObject(unittest.TestCase):
foo_partial = functools.partial(foo, a=1) foo_partial = functools.partial(foo, a=1)
sig = inspect.signature(foo_partial) sig = inspect.signature(foo_partial)
self.assertTrue(sig.parameters['a']._partial_kwarg)
for ver in range(pickle.HIGHEST_PROTOCOL + 1): for ver in range(pickle.HIGHEST_PROTOCOL + 1):
with self.subTest(pickle_ver=ver, subclass=False): with self.subTest(pickle_ver=ver, subclass=False):
sig_pickled = pickle.loads(pickle.dumps(sig, ver)) sig_pickled = pickle.loads(pickle.dumps(sig, ver))
self.assertEqual(sig, sig_pickled) self.assertEqual(sig, sig_pickled)
self.assertTrue(sig_pickled.parameters['a']._partial_kwarg)
# Test that basic sub-classing works # Test that basic sub-classing works
sig = inspect.signature(foo) sig = inspect.signature(foo)
@ -2005,6 +2003,8 @@ class TestSignatureObject(unittest.TestCase):
def test_signature_on_partial(self): def test_signature_on_partial(self):
from functools import partial from functools import partial
Parameter = inspect.Parameter
def test(): def test():
pass pass
@ -2040,15 +2040,22 @@ class TestSignatureObject(unittest.TestCase):
self.assertEqual(self.signature(partial(test, b=1, c=2)), self.assertEqual(self.signature(partial(test, b=1, c=2)),
((('a', ..., ..., "positional_or_keyword"), ((('a', ..., ..., "positional_or_keyword"),
('b', 1, ..., "positional_or_keyword"), ('b', 1, ..., "keyword_only"),
('c', 2, ..., "keyword_only"), ('c', 2, ..., "keyword_only"),
('d', ..., ..., "keyword_only")), ('d', ..., ..., "keyword_only")),
...)) ...))
self.assertEqual(self.signature(partial(test, 0, b=1, c=2)), self.assertEqual(self.signature(partial(test, 0, b=1, c=2)),
((('b', 1, ..., "positional_or_keyword"), ((('b', 1, ..., "keyword_only"),
('c', 2, ..., "keyword_only"), ('c', 2, ..., "keyword_only"),
('d', ..., ..., "keyword_only"),), ('d', ..., ..., "keyword_only")),
...))
self.assertEqual(self.signature(partial(test, a=1)),
((('a', 1, ..., "keyword_only"),
('b', ..., ..., "keyword_only"),
('c', ..., ..., "keyword_only"),
('d', ..., ..., "keyword_only")),
...)) ...))
def test(a, *args, b, **kwargs): def test(a, *args, b, **kwargs):
@ -2060,13 +2067,18 @@ class TestSignatureObject(unittest.TestCase):
('kwargs', ..., ..., "var_keyword")), ('kwargs', ..., ..., "var_keyword")),
...)) ...))
self.assertEqual(self.signature(partial(test, a=1)),
((('a', 1, ..., "keyword_only"),
('b', ..., ..., "keyword_only"),
('kwargs', ..., ..., "var_keyword")),
...))
self.assertEqual(self.signature(partial(test, 1, 2, 3)), self.assertEqual(self.signature(partial(test, 1, 2, 3)),
((('args', ..., ..., "var_positional"), ((('args', ..., ..., "var_positional"),
('b', ..., ..., "keyword_only"), ('b', ..., ..., "keyword_only"),
('kwargs', ..., ..., "var_keyword")), ('kwargs', ..., ..., "var_keyword")),
...)) ...))
self.assertEqual(self.signature(partial(test, 1, 2, 3, test=True)), self.assertEqual(self.signature(partial(test, 1, 2, 3, test=True)),
((('args', ..., ..., "var_positional"), ((('args', ..., ..., "var_positional"),
('b', ..., ..., "keyword_only"), ('b', ..., ..., "keyword_only"),
@ -2113,7 +2125,7 @@ class TestSignatureObject(unittest.TestCase):
return a return a
_foo = partial(partial(foo, a=10), a=20) _foo = partial(partial(foo, a=10), a=20)
self.assertEqual(self.signature(_foo), self.assertEqual(self.signature(_foo),
((('a', 20, ..., "positional_or_keyword"),), ((('a', 20, ..., "keyword_only"),),
...)) ...))
# check that we don't have any side-effects in signature(), # check that we don't have any side-effects in signature(),
# and the partial object is still functioning # and the partial object is still functioning
@ -2122,42 +2134,87 @@ class TestSignatureObject(unittest.TestCase):
def foo(a, b, c): def foo(a, b, c):
return a, b, c return a, b, c
_foo = partial(partial(foo, 1, b=20), b=30) _foo = partial(partial(foo, 1, b=20), b=30)
self.assertEqual(self.signature(_foo), self.assertEqual(self.signature(_foo),
((('b', 30, ..., "positional_or_keyword"), ((('b', 30, ..., "keyword_only"),
('c', ..., ..., "positional_or_keyword")), ('c', ..., ..., "keyword_only")),
...)) ...))
self.assertEqual(_foo(c=10), (1, 30, 10)) self.assertEqual(_foo(c=10), (1, 30, 10))
_foo = partial(_foo, 2) # now 'b' has two values -
# positional and keyword
with self.assertRaisesRegex(ValueError, "has incorrect arguments"):
inspect.signature(_foo)
def foo(a, b, c, *, d): def foo(a, b, c, *, d):
return a, b, c, d return a, b, c, d
_foo = partial(partial(foo, d=20, c=20), b=10, d=30) _foo = partial(partial(foo, d=20, c=20), b=10, d=30)
self.assertEqual(self.signature(_foo), self.assertEqual(self.signature(_foo),
((('a', ..., ..., "positional_or_keyword"), ((('a', ..., ..., "positional_or_keyword"),
('b', 10, ..., "positional_or_keyword"), ('b', 10, ..., "keyword_only"),
('c', 20, ..., "positional_or_keyword"), ('c', 20, ..., "keyword_only"),
('d', 30, ..., "keyword_only")), ('d', 30, ..., "keyword_only"),
),
...)) ...))
ba = inspect.signature(_foo).bind(a=200, b=11) ba = inspect.signature(_foo).bind(a=200, b=11)
self.assertEqual(_foo(*ba.args, **ba.kwargs), (200, 11, 20, 30)) self.assertEqual(_foo(*ba.args, **ba.kwargs), (200, 11, 20, 30))
def foo(a=1, b=2, c=3): def foo(a=1, b=2, c=3):
return a, b, c return a, b, c
_foo = partial(foo, a=10, c=13) _foo = partial(foo, c=13) # (a=1, b=2, *, c=13)
ba = inspect.signature(_foo).bind(11)
ba = inspect.signature(_foo).bind(a=11)
self.assertEqual(_foo(*ba.args, **ba.kwargs), (11, 2, 13)) self.assertEqual(_foo(*ba.args, **ba.kwargs), (11, 2, 13))
ba = inspect.signature(_foo).bind(11, 12) ba = inspect.signature(_foo).bind(11, 12)
self.assertEqual(_foo(*ba.args, **ba.kwargs), (11, 12, 13)) self.assertEqual(_foo(*ba.args, **ba.kwargs), (11, 12, 13))
ba = inspect.signature(_foo).bind(11, b=12) ba = inspect.signature(_foo).bind(11, b=12)
self.assertEqual(_foo(*ba.args, **ba.kwargs), (11, 12, 13)) self.assertEqual(_foo(*ba.args, **ba.kwargs), (11, 12, 13))
ba = inspect.signature(_foo).bind(b=12) ba = inspect.signature(_foo).bind(b=12)
self.assertEqual(_foo(*ba.args, **ba.kwargs), (10, 12, 13)) self.assertEqual(_foo(*ba.args, **ba.kwargs), (1, 12, 13))
_foo = partial(_foo, b=10)
ba = inspect.signature(_foo).bind(12, 14) _foo = partial(_foo, b=10, c=20)
self.assertEqual(_foo(*ba.args, **ba.kwargs), (12, 14, 13)) ba = inspect.signature(_foo).bind(12)
self.assertEqual(_foo(*ba.args, **ba.kwargs), (12, 10, 20))
def foo(a, b, c, d, **kwargs):
pass
sig = inspect.signature(foo)
params = sig.parameters.copy()
params['a'] = params['a'].replace(kind=Parameter.POSITIONAL_ONLY)
params['b'] = params['b'].replace(kind=Parameter.POSITIONAL_ONLY)
foo.__signature__ = inspect.Signature(params.values())
sig = inspect.signature(foo)
self.assertEqual(str(sig), '(a, b, /, c, d, **kwargs)')
self.assertEqual(self.signature(partial(foo, 1)),
((('b', ..., ..., 'positional_only'),
('c', ..., ..., 'positional_or_keyword'),
('d', ..., ..., 'positional_or_keyword'),
('kwargs', ..., ..., 'var_keyword')),
...))
self.assertEqual(self.signature(partial(foo, 1, 2)),
((('c', ..., ..., 'positional_or_keyword'),
('d', ..., ..., 'positional_or_keyword'),
('kwargs', ..., ..., 'var_keyword')),
...))
self.assertEqual(self.signature(partial(foo, 1, 2, 3)),
((('d', ..., ..., 'positional_or_keyword'),
('kwargs', ..., ..., 'var_keyword')),
...))
self.assertEqual(self.signature(partial(foo, 1, 2, c=3)),
((('c', 3, ..., 'keyword_only'),
('d', ..., ..., 'keyword_only'),
('kwargs', ..., ..., 'var_keyword')),
...))
self.assertEqual(self.signature(partial(foo, 1, c=3)),
((('b', ..., ..., 'positional_only'),
('c', 3, ..., 'keyword_only'),
('d', ..., ..., 'keyword_only'),
('kwargs', ..., ..., 'var_keyword')),
...))
def test_signature_on_partialmethod(self): def test_signature_on_partialmethod(self):
from functools import partialmethod from functools import partialmethod

View File

@ -149,6 +149,11 @@ Library
(Original patches by Hirokazu Yamamoto and Amaury Forgeot d'Arc, with (Original patches by Hirokazu Yamamoto and Amaury Forgeot d'Arc, with
suggested wording by David Gutteridge) suggested wording by David Gutteridge)
- Issue #21117: Fix inspect.signature to better support functools.partial.
Due to the specifics of functools.partial implementation,
positional-or-keyword arguments passed as keyword arguments become
keyword-only.
IDLE IDLE
---- ----