bpo-40396: Support GenericAlias in the typing functions. (GH-19718)

This commit is contained in:
Serhiy Storchaka 2020-04-26 21:21:08 +03:00 committed by GitHub
parent cfaf4c09ab
commit 68b352a698
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 62 additions and 6 deletions

View File

@ -22,7 +22,7 @@ from typing import NewType
from typing import NamedTuple, TypedDict from typing import NamedTuple, TypedDict
from typing import IO, TextIO, BinaryIO from typing import IO, TextIO, BinaryIO
from typing import Pattern, Match from typing import Pattern, Match
from typing import Annotated from typing import Annotated, ForwardRef
import abc import abc
import typing import typing
import weakref import weakref
@ -1756,11 +1756,17 @@ class GenericTests(BaseTestCase):
def test_generic_forward_ref(self): def test_generic_forward_ref(self):
def foobar(x: List[List['CC']]): ... def foobar(x: List[List['CC']]): ...
def foobar2(x: list[list[ForwardRef('CC')]]): ...
class CC: ... class CC: ...
self.assertEqual( self.assertEqual(
get_type_hints(foobar, globals(), locals()), get_type_hints(foobar, globals(), locals()),
{'x': List[List[CC]]} {'x': List[List[CC]]}
) )
self.assertEqual(
get_type_hints(foobar2, globals(), locals()),
{'x': list[list[CC]]}
)
T = TypeVar('T') T = TypeVar('T')
AT = Tuple[T, ...] AT = Tuple[T, ...]
def barfoo(x: AT): ... def barfoo(x: AT): ...
@ -2446,6 +2452,12 @@ class ForwardRefTests(BaseTestCase):
self.assertEqual(get_type_hints(foo, globals(), locals()), self.assertEqual(get_type_hints(foo, globals(), locals()),
{'a': Tuple[T]}) {'a': Tuple[T]})
def foo(a: tuple[ForwardRef('T')]):
pass
self.assertEqual(get_type_hints(foo, globals(), locals()),
{'a': tuple[T]})
def test_forward_recursion_actually(self): def test_forward_recursion_actually(self):
def namespace1(): def namespace1():
a = typing.ForwardRef('A') a = typing.ForwardRef('A')
@ -2909,6 +2921,18 @@ class GetTypeHintTests(BaseTestCase):
get_type_hints(foobar, globals(), locals(), include_extras=True), get_type_hints(foobar, globals(), locals(), include_extras=True),
{'x': List[Annotated[int, (1, 10)]]} {'x': List[Annotated[int, (1, 10)]]}
) )
def foobar(x: list[ForwardRef('X')]): ...
X = Annotated[int, (1, 10)]
self.assertEqual(
get_type_hints(foobar, globals(), locals()),
{'x': list[int]}
)
self.assertEqual(
get_type_hints(foobar, globals(), locals(), include_extras=True),
{'x': list[Annotated[int, (1, 10)]]}
)
BA = Tuple[Annotated[T, (1, 0)], ...] BA = Tuple[Annotated[T, (1, 0)], ...]
def barfoo(x: BA): ... def barfoo(x: BA): ...
self.assertEqual(get_type_hints(barfoo, globals(), locals())['x'], Tuple[T, ...]) self.assertEqual(get_type_hints(barfoo, globals(), locals())['x'], Tuple[T, ...])
@ -2916,12 +2940,22 @@ class GetTypeHintTests(BaseTestCase):
get_type_hints(barfoo, globals(), locals(), include_extras=True)['x'], get_type_hints(barfoo, globals(), locals(), include_extras=True)['x'],
BA BA
) )
BA = tuple[Annotated[T, (1, 0)], ...]
def barfoo(x: BA): ...
self.assertEqual(get_type_hints(barfoo, globals(), locals())['x'], tuple[T, ...])
self.assertIs(
get_type_hints(barfoo, globals(), locals(), include_extras=True)['x'],
BA
)
def barfoo2(x: typing.Callable[..., Annotated[List[T], "const"]], def barfoo2(x: typing.Callable[..., Annotated[List[T], "const"]],
y: typing.Union[int, Annotated[T, "mutable"]]): ... y: typing.Union[int, Annotated[T, "mutable"]]): ...
self.assertEqual( self.assertEqual(
get_type_hints(barfoo2, globals(), locals()), get_type_hints(barfoo2, globals(), locals()),
{'x': typing.Callable[..., List[T]], 'y': typing.Union[int, T]} {'x': typing.Callable[..., List[T]], 'y': typing.Union[int, T]}
) )
BA2 = typing.Callable[..., List[T]] BA2 = typing.Callable[..., List[T]]
def barfoo3(x: BA2): ... def barfoo3(x: BA2): ...
self.assertIs( self.assertIs(
@ -2972,6 +3006,9 @@ class GetUtilitiesTestCase(TestCase):
self.assertIs(get_origin(Generic[T]), Generic) self.assertIs(get_origin(Generic[T]), Generic)
self.assertIs(get_origin(List[Tuple[T, T]][int]), list) self.assertIs(get_origin(List[Tuple[T, T]][int]), list)
self.assertIs(get_origin(Annotated[T, 'thing']), Annotated) self.assertIs(get_origin(Annotated[T, 'thing']), Annotated)
self.assertIs(get_origin(List), list)
self.assertIs(get_origin(list[int]), list)
self.assertIs(get_origin(list), None)
def test_get_args(self): def test_get_args(self):
T = TypeVar('T') T = TypeVar('T')
@ -2993,6 +3030,9 @@ class GetUtilitiesTestCase(TestCase):
self.assertEqual(get_args(Tuple[int, ...]), (int, ...)) self.assertEqual(get_args(Tuple[int, ...]), (int, ...))
self.assertEqual(get_args(Tuple[()]), ((),)) self.assertEqual(get_args(Tuple[()]), ((),))
self.assertEqual(get_args(Annotated[T, 'one', 2, ['three']]), (T, 'one', 2, ['three'])) self.assertEqual(get_args(Annotated[T, 'one', 2, ['three']]), (T, 'one', 2, ['three']))
self.assertEqual(get_args(List), (typing.T,))
self.assertEqual(get_args(list[int]), (int,))
self.assertEqual(get_args(list), ())
class CollectionsAbcTests(BaseTestCase): class CollectionsAbcTests(BaseTestCase):

View File

@ -191,7 +191,7 @@ def _subs_tvars(tp, tvars, subs):
"""Substitute type variables 'tvars' with substitutions 'subs'. """Substitute type variables 'tvars' with substitutions 'subs'.
These two must have the same length. These two must have the same length.
""" """
if not isinstance(tp, _GenericAlias): if not isinstance(tp, (_GenericAlias, GenericAlias)):
return tp return tp
new_args = list(tp.__args__) new_args = list(tp.__args__)
for a, arg in enumerate(tp.__args__): for a, arg in enumerate(tp.__args__):
@ -203,7 +203,10 @@ def _subs_tvars(tp, tvars, subs):
new_args[a] = _subs_tvars(arg, tvars, subs) new_args[a] = _subs_tvars(arg, tvars, subs)
if tp.__origin__ is Union: if tp.__origin__ is Union:
return Union[tuple(new_args)] return Union[tuple(new_args)]
return tp.copy_with(tuple(new_args)) if isinstance(tp, GenericAlias):
return GenericAlias(tp.__origin__, tuple(new_args))
else:
return tp.copy_with(tuple(new_args))
def _check_generic(cls, parameters): def _check_generic(cls, parameters):
@ -278,6 +281,11 @@ def _eval_type(t, globalns, localns):
res = t.copy_with(ev_args) res = t.copy_with(ev_args)
res._special = t._special res._special = t._special
return res return res
if isinstance(t, GenericAlias):
ev_args = tuple(_eval_type(a, globalns, localns) for a in t.__args__)
if ev_args == t.__args__:
return t
return GenericAlias(t.__origin__, ev_args)
return t return t
@ -1368,6 +1376,11 @@ def _strip_annotations(t):
res = t.copy_with(stripped_args) res = t.copy_with(stripped_args)
res._special = t._special res._special = t._special
return res return res
if isinstance(t, GenericAlias):
stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
if stripped_args == t.__args__:
return t
return GenericAlias(t.__origin__, stripped_args)
return t return t
@ -1387,7 +1400,7 @@ def get_origin(tp):
""" """
if isinstance(tp, _AnnotatedAlias): if isinstance(tp, _AnnotatedAlias):
return Annotated return Annotated
if isinstance(tp, _GenericAlias): if isinstance(tp, (_GenericAlias, GenericAlias)):
return tp.__origin__ return tp.__origin__
if tp is Generic: if tp is Generic:
return Generic return Generic
@ -1407,9 +1420,9 @@ def get_args(tp):
""" """
if isinstance(tp, _AnnotatedAlias): if isinstance(tp, _AnnotatedAlias):
return (tp.__origin__,) + tp.__metadata__ return (tp.__origin__,) + tp.__metadata__
if isinstance(tp, _GenericAlias): if isinstance(tp, (_GenericAlias, GenericAlias)):
res = tp.__args__ res = tp.__args__
if get_origin(tp) is collections.abc.Callable and res[0] is not Ellipsis: if tp.__origin__ is collections.abc.Callable and res[0] is not Ellipsis:
res = (list(res[:-1]), res[-1]) res = (list(res[:-1]), res[-1])
return res return res
return () return ()

View File

@ -0,0 +1,3 @@
Functions :func:`typing.get_origin`, :func:`typing.get_args` and
:func:`typing.get_type_hints` support now generic aliases like
``list[int]``.