bpo-40397: Refactor typing._GenericAlias (GH-19719)

Make the design more object-oriented.
Split _GenericAlias on two almost independent classes: for special
generic aliases like List and for parametrized generic aliases like List[int].
Add specialized subclasses for Callable, Callable[...], Tuple and Union[...].
This commit is contained in:
Serhiy Storchaka 2020-05-07 04:09:33 +03:00 committed by GitHub
parent 470aac4d8e
commit c1c7d8ead9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 209 additions and 202 deletions

View File

@ -181,34 +181,11 @@ def _collect_type_vars(types):
for t in types:
if isinstance(t, TypeVar) and t not in tvars:
tvars.append(t)
if ((isinstance(t, _GenericAlias) and not t._special)
or isinstance(t, GenericAlias)):
if isinstance(t, (_GenericAlias, GenericAlias)):
tvars.extend([t for t in t.__parameters__ if t not in tvars])
return tuple(tvars)
def _subs_tvars(tp, tvars, subs):
"""Substitute type variables 'tvars' with substitutions 'subs'.
These two must have the same length.
"""
if not isinstance(tp, (_GenericAlias, GenericAlias)):
return tp
new_args = list(tp.__args__)
for a, arg in enumerate(tp.__args__):
if isinstance(arg, TypeVar):
for i, tvar in enumerate(tvars):
if arg == tvar:
new_args[a] = subs[i]
else:
new_args[a] = _subs_tvars(arg, tvars, subs)
if tp.__origin__ is Union:
return Union[tuple(new_args)]
if isinstance(tp, GenericAlias):
return GenericAlias(tp.__origin__, tuple(new_args))
else:
return tp.copy_with(tuple(new_args))
def _check_generic(cls, parameters):
"""Check correct count for parameters of a generic cls (internal helper).
This gives a nice error message in case of count mismatch.
@ -229,7 +206,7 @@ def _remove_dups_flatten(parameters):
# Flatten out Union[Union[...], ...].
params = []
for p in parameters:
if isinstance(p, _GenericAlias) and p.__origin__ is Union:
if isinstance(p, _UnionGenericAlias):
params.extend(p.__args__)
elif isinstance(p, tuple) and len(p) > 0 and p[0] is Union:
params.extend(p[1:])
@ -274,18 +251,14 @@ def _eval_type(t, globalns, localns):
"""
if isinstance(t, ForwardRef):
return t._evaluate(globalns, localns)
if isinstance(t, _GenericAlias):
if isinstance(t, (_GenericAlias, GenericAlias)):
ev_args = tuple(_eval_type(a, globalns, localns) for a in t.__args__)
if ev_args == t.__args__:
return t
res = t.copy_with(ev_args)
res._special = t._special
return res
if isinstance(t, GenericAlias):
ev_args = tuple(_eval_type(a, globalns, localns) for a in t.__args__)
if ev_args == t.__args__:
return t
return GenericAlias(t.__origin__, ev_args)
if isinstance(t, GenericAlias):
return GenericAlias(t.__origin__, ev_args)
else:
return t.copy_with(ev_args)
return t
@ -300,6 +273,7 @@ class _Final:
class _Immutable:
"""Mixin to indicate that object should not be copied."""
__slots__ = ()
def __copy__(self):
return self
@ -446,7 +420,7 @@ def Union(self, parameters):
parameters = _remove_dups_flatten(parameters)
if len(parameters) == 1:
return parameters[0]
return _GenericAlias(self, parameters)
return _UnionGenericAlias(self, parameters)
@_SpecialForm
def Optional(self, parameters):
@ -579,7 +553,7 @@ class TypeVar(_Final, _Immutable, _root=True):
"""
__slots__ = ('__name__', '__bound__', '__constraints__',
'__covariant__', '__contravariant__')
'__covariant__', '__contravariant__', '__dict__')
def __init__(self, name, *constraints, bound=None,
covariant=False, contravariant=False):
@ -629,23 +603,10 @@ class TypeVar(_Final, _Immutable, _root=True):
# e.g., Dict[T, int].__args__ == (T, int).
# Mapping from non-generic type names that have a generic alias in typing
# but with a different name.
_normalize_alias = {'list': 'List',
'tuple': 'Tuple',
'dict': 'Dict',
'set': 'Set',
'frozenset': 'FrozenSet',
'deque': 'Deque',
'defaultdict': 'DefaultDict',
'type': 'Type',
'Set': 'AbstractSet'}
def _is_dunder(attr):
return attr.startswith('__') and attr.endswith('__')
class _GenericAlias(_Final, _root=True):
class _BaseGenericAlias(_Final, _root=True):
"""The central part of internal API.
This represents a generic version of type 'origin' with type arguments 'params'.
@ -654,12 +615,8 @@ class _GenericAlias(_Final, _root=True):
have 'name' always set. If 'inst' is False, then the alias can't be instantiated,
this is used by e.g. typing.List and typing.Dict.
"""
def __init__(self, origin, params, *, inst=True, special=False, name=None):
def __init__(self, origin, params, *, inst=True, name=None):
self._inst = inst
self._special = special
if special and name is None:
orig_name = origin.__name__
name = _normalize_alias.get(orig_name, orig_name)
self._name = name
if not isinstance(params, tuple):
params = (params,)
@ -671,9 +628,61 @@ class _GenericAlias(_Final, _root=True):
self.__slots__ = None # This is not documented.
if not name:
self.__module__ = origin.__module__
if special:
self.__doc__ = f'A generic version of {origin.__module__}.{origin.__qualname__}'
def __eq__(self, other):
if not isinstance(other, _BaseGenericAlias):
return NotImplemented
return (self.__origin__ == other.__origin__
and self.__args__ == other.__args__)
def __hash__(self):
return hash((self.__origin__, self.__args__))
def __call__(self, *args, **kwargs):
if not self._inst:
raise TypeError(f"Type {self._name} cannot be instantiated; "
f"use {self.__origin__.__name__}() instead")
result = self.__origin__(*args, **kwargs)
try:
result.__orig_class__ = self
except AttributeError:
pass
return result
def __mro_entries__(self, bases):
res = []
if self.__origin__ not in bases:
res.append(self.__origin__)
i = bases.index(self)
for b in bases[i+1:]:
if isinstance(b, _BaseGenericAlias) or issubclass(b, Generic):
break
else:
res.append(Generic)
return tuple(res)
def __getattr__(self, attr):
# We are careful for copy and pickle.
# Also for simplicity we just don't relay all dunder names
if '__origin__' in self.__dict__ and not _is_dunder(attr):
return getattr(self.__origin__, attr)
raise AttributeError(attr)
def __setattr__(self, attr, val):
if _is_dunder(attr) or attr in ('_name', '_inst'):
super().__setattr__(attr, val)
else:
setattr(self.__origin__, attr, val)
def __instancecheck__(self, obj):
return self.__subclasscheck__(type(obj))
def __subclasscheck__(self, cls):
raise TypeError("Subscripted generics cannot be used with"
" class and instance checks")
class _GenericAlias(_BaseGenericAlias, _root=True):
@_tp_cache
def __getitem__(self, params):
if self.__origin__ in (Generic, Protocol):
@ -684,131 +693,109 @@ class _GenericAlias(_Final, _root=True):
msg = "Parameters to generic types must be types."
params = tuple(_type_check(p, msg) for p in params)
_check_generic(self, params)
return _subs_tvars(self, self.__parameters__, params)
subst = dict(zip(self.__parameters__, params))
new_args = []
for arg in self.__args__:
if isinstance(arg, TypeVar):
arg = subst[arg]
elif isinstance(arg, (_BaseGenericAlias, GenericAlias)):
subargs = tuple(subst[x] for x in arg.__parameters__)
arg = arg[subargs]
new_args.append(arg)
return self.copy_with(tuple(new_args))
def copy_with(self, params):
# We don't copy self._special.
return _GenericAlias(self.__origin__, params, name=self._name, inst=self._inst)
return self.__class__(self.__origin__, params, name=self._name, inst=self._inst)
def __repr__(self):
if (self.__origin__ == Union and len(self.__args__) == 2
and type(None) in self.__args__):
if self.__args__[0] is not type(None):
arg = self.__args__[0]
else:
arg = self.__args__[1]
return (f'typing.Optional[{_type_repr(arg)}]')
if (self._name != 'Callable' or
len(self.__args__) == 2 and self.__args__[0] is Ellipsis):
if self._name:
name = 'typing.' + self._name
else:
name = _type_repr(self.__origin__)
if not self._special:
args = f'[{", ".join([_type_repr(a) for a in self.__args__])}]'
else:
args = ''
return (f'{name}{args}')
if self._special:
return 'typing.Callable'
return (f'typing.Callable'
f'[[{", ".join([_type_repr(a) for a in self.__args__[:-1]])}], '
f'{_type_repr(self.__args__[-1])}]')
if self._name:
name = 'typing.' + self._name
else:
name = _type_repr(self.__origin__)
args = ", ".join([_type_repr(a) for a in self.__args__])
return f'{name}[{args}]'
def __eq__(self, other):
if not isinstance(other, _GenericAlias):
return NotImplemented
if self.__origin__ != other.__origin__:
return False
if self.__origin__ is Union and other.__origin__ is Union:
return frozenset(self.__args__) == frozenset(other.__args__)
return self.__args__ == other.__args__
def __hash__(self):
if self.__origin__ is Union:
return hash((Union, frozenset(self.__args__)))
return hash((self.__origin__, self.__args__))
def __call__(self, *args, **kwargs):
if not self._inst:
raise TypeError(f"Type {self._name} cannot be instantiated; "
f"use {self._name.lower()}() instead")
result = self.__origin__(*args, **kwargs)
try:
result.__orig_class__ = self
except AttributeError:
pass
return result
def __reduce__(self):
if self._name:
origin = globals()[self._name]
else:
origin = self.__origin__
args = tuple(self.__args__)
if len(args) == 1 and not isinstance(args[0], tuple):
args, = args
return operator.getitem, (origin, args)
def __mro_entries__(self, bases):
if self._name: # generic version of an ABC or built-in class
res = []
if self.__origin__ not in bases:
res.append(self.__origin__)
i = bases.index(self)
if not any(isinstance(b, _GenericAlias) or issubclass(b, Generic)
for b in bases[i+1:]):
res.append(Generic)
return tuple(res)
return super().__mro_entries__(bases)
if self.__origin__ is Generic:
if Protocol in bases:
return ()
i = bases.index(self)
for b in bases[i+1:]:
if isinstance(b, _GenericAlias) and b is not self:
if isinstance(b, _BaseGenericAlias) and b is not self:
return ()
return (self.__origin__,)
def __getattr__(self, attr):
# We are careful for copy and pickle.
# Also for simplicity we just don't relay all dunder names
if '__origin__' in self.__dict__ and not _is_dunder(attr):
return getattr(self.__origin__, attr)
raise AttributeError(attr)
def __setattr__(self, attr, val):
if _is_dunder(attr) or attr in ('_name', '_inst', '_special'):
super().__setattr__(attr, val)
else:
setattr(self.__origin__, attr, val)
class _SpecialGenericAlias(_BaseGenericAlias, _root=True):
def __init__(self, origin, params, *, inst=True, name=None):
if name is None:
name = origin.__name__
super().__init__(origin, params, inst=inst, name=name)
self.__doc__ = f'A generic version of {origin.__module__}.{origin.__qualname__}'
def __instancecheck__(self, obj):
return self.__subclasscheck__(type(obj))
@_tp_cache
def __getitem__(self, params):
if not isinstance(params, tuple):
params = (params,)
msg = "Parameters to generic types must be types."
params = tuple(_type_check(p, msg) for p in params)
_check_generic(self, params)
assert self.__args__ == self.__parameters__
return self.copy_with(params)
def copy_with(self, params):
return _GenericAlias(self.__origin__, params,
name=self._name, inst=self._inst)
def __repr__(self):
return 'typing.' + self._name
def __subclasscheck__(self, cls):
if self._special:
if not isinstance(cls, _GenericAlias):
return issubclass(cls, self.__origin__)
if cls._special:
return issubclass(cls.__origin__, self.__origin__)
raise TypeError("Subscripted generics cannot be used with"
" class and instance checks")
if isinstance(cls, _SpecialGenericAlias):
return issubclass(cls.__origin__, self.__origin__)
if not isinstance(cls, _GenericAlias):
return issubclass(cls, self.__origin__)
return super().__subclasscheck__(cls)
def __reduce__(self):
if self._special:
return self._name
if self._name:
origin = globals()[self._name]
else:
origin = self.__origin__
if (origin is Callable and
not (len(self.__args__) == 2 and self.__args__[0] is Ellipsis)):
args = list(self.__args__[:-1]), self.__args__[-1]
else:
args = tuple(self.__args__)
if len(args) == 1 and not isinstance(args[0], tuple):
args, = args
return operator.getitem, (origin, args)
return self._name
class _VariadicGenericAlias(_GenericAlias, _root=True):
"""Same as _GenericAlias above but for variadic aliases. Currently,
this is used only by special internal aliases: Tuple and Callable.
"""
class _CallableGenericAlias(_GenericAlias, _root=True):
def __repr__(self):
assert self._name == 'Callable'
if len(self.__args__) == 2 and self.__args__[0] is Ellipsis:
return super().__repr__()
return (f'typing.Callable'
f'[[{", ".join([_type_repr(a) for a in self.__args__[:-1]])}], '
f'{_type_repr(self.__args__[-1])}]')
def __reduce__(self):
args = self.__args__
if not (len(args) == 2 and args[0] is ...):
args = list(args[:-1]), args[-1]
return operator.getitem, (Callable, args)
class _CallableType(_SpecialGenericAlias, _root=True):
def copy_with(self, params):
return _CallableGenericAlias(self.__origin__, params,
name=self._name, inst=self._inst)
def __getitem__(self, params):
if self._name != 'Callable' or not self._special:
return self.__getitem_inner__(params)
if not isinstance(params, tuple) or len(params) != 2:
raise TypeError("Callable must be used as "
"Callable[[arg, ...], result].")
@ -824,29 +811,53 @@ class _VariadicGenericAlias(_GenericAlias, _root=True):
@_tp_cache
def __getitem_inner__(self, params):
if self.__origin__ is tuple and self._special:
if params == ():
return self.copy_with((_TypingEmpty,))
if not isinstance(params, tuple):
params = (params,)
if len(params) == 2 and params[1] is ...:
msg = "Tuple[t, ...]: t must be a type."
p = _type_check(params[0], msg)
return self.copy_with((p, _TypingEllipsis))
msg = "Tuple[t0, t1, ...]: each t must be a type."
params = tuple(_type_check(p, msg) for p in params)
return self.copy_with(params)
if self.__origin__ is collections.abc.Callable and self._special:
args, result = params
msg = "Callable[args, result]: result must be a type."
result = _type_check(result, msg)
if args is Ellipsis:
return self.copy_with((_TypingEllipsis, result))
msg = "Callable[[arg, ...], result]: each arg must be a type."
args = tuple(_type_check(arg, msg) for arg in args)
params = args + (result,)
return self.copy_with(params)
return super().__getitem__(params)
args, result = params
msg = "Callable[args, result]: result must be a type."
result = _type_check(result, msg)
if args is Ellipsis:
return self.copy_with((_TypingEllipsis, result))
msg = "Callable[[arg, ...], result]: each arg must be a type."
args = tuple(_type_check(arg, msg) for arg in args)
params = args + (result,)
return self.copy_with(params)
class _TupleType(_SpecialGenericAlias, _root=True):
@_tp_cache
def __getitem__(self, params):
if params == ():
return self.copy_with((_TypingEmpty,))
if not isinstance(params, tuple):
params = (params,)
if len(params) == 2 and params[1] is ...:
msg = "Tuple[t, ...]: t must be a type."
p = _type_check(params[0], msg)
return self.copy_with((p, _TypingEllipsis))
msg = "Tuple[t0, t1, ...]: each t must be a type."
params = tuple(_type_check(p, msg) for p in params)
return self.copy_with(params)
class _UnionGenericAlias(_GenericAlias, _root=True):
def copy_with(self, params):
return Union[params]
def __eq__(self, other):
if not isinstance(other, _UnionGenericAlias):
return NotImplemented
return set(self.__args__) == set(other.__args__)
def __hash__(self):
return hash(frozenset(self.__args__))
def __repr__(self):
args = self.__args__
if len(args) == 2:
if args[0] is type(None):
return f'typing.Optional[{_type_repr(args[1])}]'
elif args[1] is type(None):
return f'typing.Optional[{_type_repr(args[0])}]'
return super().__repr__()
class Generic:
@ -1162,9 +1173,8 @@ class _AnnotatedAlias(_GenericAlias, _root=True):
def __eq__(self, other):
if not isinstance(other, _AnnotatedAlias):
return NotImplemented
if self.__origin__ != other.__origin__:
return False
return self.__metadata__ == other.__metadata__
return (self.__origin__ == other.__origin__
and self.__metadata__ == other.__metadata__)
def __hash__(self):
return hash((self.__origin__, self.__metadata__))
@ -1380,9 +1390,7 @@ def _strip_annotations(t):
stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
if stripped_args == t.__args__:
return t
res = t.copy_with(stripped_args)
res._special = t._special
return res
return t.copy_with(stripped_args)
if isinstance(t, GenericAlias):
stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
if stripped_args == t.__args__:
@ -1407,7 +1415,7 @@ def get_origin(tp):
"""
if isinstance(tp, _AnnotatedAlias):
return Annotated
if isinstance(tp, (_GenericAlias, GenericAlias)):
if isinstance(tp, (_BaseGenericAlias, GenericAlias)):
return tp.__origin__
if tp is Generic:
return Generic
@ -1427,7 +1435,7 @@ def get_args(tp):
"""
if isinstance(tp, _AnnotatedAlias):
return (tp.__origin__,) + tp.__metadata__
if isinstance(tp, _GenericAlias) and not tp._special:
if isinstance(tp, _GenericAlias):
res = tp.__args__
if tp.__origin__ is collections.abc.Callable and res[0] is not Ellipsis:
res = (list(res[:-1]), res[-1])
@ -1561,8 +1569,7 @@ AnyStr = TypeVar('AnyStr', bytes, str)
# Various ABCs mimicking those in collections.abc.
def _alias(origin, params, inst=True):
return _GenericAlias(origin, params, special=True, inst=inst)
_alias = _SpecialGenericAlias
Hashable = _alias(collections.abc.Hashable, ()) # Not generic.
Awaitable = _alias(collections.abc.Awaitable, T_co)
@ -1575,7 +1582,7 @@ Reversible = _alias(collections.abc.Reversible, T_co)
Sized = _alias(collections.abc.Sized, ()) # Not generic.
Container = _alias(collections.abc.Container, T_co)
Collection = _alias(collections.abc.Collection, T_co)
Callable = _VariadicGenericAlias(collections.abc.Callable, (), special=True)
Callable = _CallableType(collections.abc.Callable, ())
Callable.__doc__ = \
"""Callable type; Callable[[int], str] is a function of (int) -> str.
@ -1586,7 +1593,7 @@ Callable.__doc__ = \
There is no syntax to indicate optional or keyword arguments,
such function types are rarely used as callback types.
"""
AbstractSet = _alias(collections.abc.Set, T_co)
AbstractSet = _alias(collections.abc.Set, T_co, name='AbstractSet')
MutableSet = _alias(collections.abc.MutableSet, T)
# NOTE: Mapping is only covariant in the value type.
Mapping = _alias(collections.abc.Mapping, (KT, VT_co))
@ -1594,7 +1601,7 @@ MutableMapping = _alias(collections.abc.MutableMapping, (KT, VT))
Sequence = _alias(collections.abc.Sequence, T_co)
MutableSequence = _alias(collections.abc.MutableSequence, T)
ByteString = _alias(collections.abc.ByteString, ()) # Not generic
Tuple = _VariadicGenericAlias(tuple, (), inst=False, special=True)
Tuple = _TupleType(tuple, (), inst=False, name='Tuple')
Tuple.__doc__ = \
"""Tuple type; Tuple[X, Y] is the cross-product type of X and Y.
@ -1604,24 +1611,24 @@ Tuple.__doc__ = \
To specify a variable-length tuple of homogeneous type, use Tuple[T, ...].
"""
List = _alias(list, T, inst=False)
Deque = _alias(collections.deque, T)
Set = _alias(set, T, inst=False)
FrozenSet = _alias(frozenset, T_co, inst=False)
List = _alias(list, T, inst=False, name='List')
Deque = _alias(collections.deque, T, name='Deque')
Set = _alias(set, T, inst=False, name='Set')
FrozenSet = _alias(frozenset, T_co, inst=False, name='FrozenSet')
MappingView = _alias(collections.abc.MappingView, T_co)
KeysView = _alias(collections.abc.KeysView, KT)
ItemsView = _alias(collections.abc.ItemsView, (KT, VT_co))
ValuesView = _alias(collections.abc.ValuesView, VT_co)
ContextManager = _alias(contextlib.AbstractContextManager, T_co)
AsyncContextManager = _alias(contextlib.AbstractAsyncContextManager, T_co)
Dict = _alias(dict, (KT, VT), inst=False)
DefaultDict = _alias(collections.defaultdict, (KT, VT))
ContextManager = _alias(contextlib.AbstractContextManager, T_co, name='ContextManager')
AsyncContextManager = _alias(contextlib.AbstractAsyncContextManager, T_co, name='AsyncContextManager')
Dict = _alias(dict, (KT, VT), inst=False, name='Dict')
DefaultDict = _alias(collections.defaultdict, (KT, VT), name='DefaultDict')
OrderedDict = _alias(collections.OrderedDict, (KT, VT))
Counter = _alias(collections.Counter, T)
ChainMap = _alias(collections.ChainMap, (KT, VT))
Generator = _alias(collections.abc.Generator, (T_co, T_contra, V_co))
AsyncGenerator = _alias(collections.abc.AsyncGenerator, (T_co, T_contra))
Type = _alias(type, CT_co, inst=False)
Type = _alias(type, CT_co, inst=False, name='Type')
Type.__doc__ = \
"""A special construct usable to annotate class objects.