bpo-42195: Override _CallableGenericAlias's __getitem__ (GH-23915)

Added `__getitem__` for `_CallableGenericAlias` so that it returns a subclass (itself) of `types.GenericAlias` rather than the default behavior of returning a plain `types.GenericAlias`. This fixes `repr` issues occuring after `TypeVar` substitution arising from the previous behavior.
This commit is contained in:
kj 2020-12-24 10:47:40 +08:00 committed by GitHub
parent eee1c7745a
commit 6dd3da3cf4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 1 deletions

View File

@ -434,7 +434,7 @@ class _CallableGenericAlias(GenericAlias):
raise TypeError( raise TypeError(
"Callable must be used as Callable[[arg, ...], result].") "Callable must be used as Callable[[arg, ...], result].")
t_args, t_result = args t_args, t_result = args
if isinstance(t_args, list): if isinstance(t_args, (list, tuple)):
ga_args = tuple(t_args) + (t_result,) ga_args = tuple(t_args) + (t_result,)
# This relaxes what t_args can be on purpose to allow things like # This relaxes what t_args can be on purpose to allow things like
# PEP 612 ParamSpec. Responsibility for whether a user is using # PEP 612 ParamSpec. Responsibility for whether a user is using
@ -456,6 +456,16 @@ class _CallableGenericAlias(GenericAlias):
args = list(args[:-1]), args[-1] args = list(args[:-1]), args[-1]
return _CallableGenericAlias, (Callable, args) return _CallableGenericAlias, (Callable, args)
def __getitem__(self, item):
# Called during TypeVar substitution, returns the custom subclass
# rather than the default types.GenericAlias object.
ga = super().__getitem__(item)
args = ga.__args__
t_result = args[-1]
t_args = args[:-1]
args = (t_args, t_result)
return _CallableGenericAlias(Callable, args)
def _type_repr(obj): def _type_repr(obj):
"""Return the repr() of an object, special-casing types (internal helper). """Return the repr() of an object, special-casing types (internal helper).

View File

@ -347,6 +347,12 @@ class BaseTest(unittest.TestCase):
self.assertEqual(C2[int, float, str], Callable[[int, float], str]) self.assertEqual(C2[int, float, str], Callable[[int, float], str])
self.assertEqual(C3[int], Callable[..., int]) self.assertEqual(C3[int], Callable[..., int])
# multi chaining
C4 = C2[int, V, str]
self.assertEqual(repr(C4).split(".")[-1], "Callable[[int, ~V], str]")
self.assertEqual(repr(C4[dict]).split(".")[-1], "Callable[[int, dict], str]")
self.assertEqual(C4[dict], Callable[[int, dict], str])
with self.subTest("Testing type erasure"): with self.subTest("Testing type erasure"):
class C1(Callable): class C1(Callable):
def __call__(self): def __call__(self):