bpo-42345: Fix three issues with typing.Literal parameters (GH-23294)
Literal equality no longer depends on the order of arguments. Fix issue related to `typing.Literal` caching by adding `typed` parameter to `typing._tp_cache` function. Add deduplication of `typing.Literal` arguments.
This commit is contained in:
parent
b0aba1fcdc
commit
f03d318ca4
|
@ -528,6 +528,7 @@ class LiteralTests(BaseTestCase):
|
||||||
self.assertEqual(repr(Literal[int]), "typing.Literal[int]")
|
self.assertEqual(repr(Literal[int]), "typing.Literal[int]")
|
||||||
self.assertEqual(repr(Literal), "typing.Literal")
|
self.assertEqual(repr(Literal), "typing.Literal")
|
||||||
self.assertEqual(repr(Literal[None]), "typing.Literal[None]")
|
self.assertEqual(repr(Literal[None]), "typing.Literal[None]")
|
||||||
|
self.assertEqual(repr(Literal[1, 2, 3, 3]), "typing.Literal[1, 2, 3]")
|
||||||
|
|
||||||
def test_cannot_init(self):
|
def test_cannot_init(self):
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
|
@ -559,6 +560,30 @@ class LiteralTests(BaseTestCase):
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
Literal[1][1]
|
Literal[1][1]
|
||||||
|
|
||||||
|
def test_equal(self):
|
||||||
|
self.assertNotEqual(Literal[0], Literal[False])
|
||||||
|
self.assertNotEqual(Literal[True], Literal[1])
|
||||||
|
self.assertNotEqual(Literal[1], Literal[2])
|
||||||
|
self.assertNotEqual(Literal[1, True], Literal[1])
|
||||||
|
self.assertEqual(Literal[1], Literal[1])
|
||||||
|
self.assertEqual(Literal[1, 2], Literal[2, 1])
|
||||||
|
self.assertEqual(Literal[1, 2, 3], Literal[1, 2, 3, 3])
|
||||||
|
|
||||||
|
def test_args(self):
|
||||||
|
self.assertEqual(Literal[1, 2, 3].__args__, (1, 2, 3))
|
||||||
|
self.assertEqual(Literal[1, 2, 3, 3].__args__, (1, 2, 3))
|
||||||
|
self.assertEqual(Literal[1, Literal[2], Literal[3, 4]].__args__, (1, 2, 3, 4))
|
||||||
|
# Mutable arguments will not be deduplicated
|
||||||
|
self.assertEqual(Literal[[], []].__args__, ([], []))
|
||||||
|
|
||||||
|
def test_flatten(self):
|
||||||
|
l1 = Literal[Literal[1], Literal[2], Literal[3]]
|
||||||
|
l2 = Literal[Literal[1, 2], 3]
|
||||||
|
l3 = Literal[Literal[1, 2, 3]]
|
||||||
|
for l in l1, l2, l3:
|
||||||
|
self.assertEqual(l, Literal[1, 2, 3])
|
||||||
|
self.assertEqual(l.__args__, (1, 2, 3))
|
||||||
|
|
||||||
|
|
||||||
XK = TypeVar('XK', str, bytes)
|
XK = TypeVar('XK', str, bytes)
|
||||||
XV = TypeVar('XV')
|
XV = TypeVar('XV')
|
||||||
|
|
|
@ -202,6 +202,20 @@ def _check_generic(cls, parameters, elen):
|
||||||
f" actual {alen}, expected {elen}")
|
f" actual {alen}, expected {elen}")
|
||||||
|
|
||||||
|
|
||||||
|
def _deduplicate(params):
|
||||||
|
# Weed out strict duplicates, preserving the first of each occurrence.
|
||||||
|
all_params = set(params)
|
||||||
|
if len(all_params) < len(params):
|
||||||
|
new_params = []
|
||||||
|
for t in params:
|
||||||
|
if t in all_params:
|
||||||
|
new_params.append(t)
|
||||||
|
all_params.remove(t)
|
||||||
|
params = new_params
|
||||||
|
assert not all_params, all_params
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
def _remove_dups_flatten(parameters):
|
def _remove_dups_flatten(parameters):
|
||||||
"""An internal helper for Union creation and substitution: flatten Unions
|
"""An internal helper for Union creation and substitution: flatten Unions
|
||||||
among parameters, then remove duplicates.
|
among parameters, then remove duplicates.
|
||||||
|
@ -215,38 +229,45 @@ def _remove_dups_flatten(parameters):
|
||||||
params.extend(p[1:])
|
params.extend(p[1:])
|
||||||
else:
|
else:
|
||||||
params.append(p)
|
params.append(p)
|
||||||
# Weed out strict duplicates, preserving the first of each occurrence.
|
|
||||||
all_params = set(params)
|
return tuple(_deduplicate(params))
|
||||||
if len(all_params) < len(params):
|
|
||||||
new_params = []
|
|
||||||
for t in params:
|
def _flatten_literal_params(parameters):
|
||||||
if t in all_params:
|
"""An internal helper for Literal creation: flatten Literals among parameters"""
|
||||||
new_params.append(t)
|
params = []
|
||||||
all_params.remove(t)
|
for p in parameters:
|
||||||
params = new_params
|
if isinstance(p, _LiteralGenericAlias):
|
||||||
assert not all_params, all_params
|
params.extend(p.__args__)
|
||||||
|
else:
|
||||||
|
params.append(p)
|
||||||
return tuple(params)
|
return tuple(params)
|
||||||
|
|
||||||
|
|
||||||
_cleanups = []
|
_cleanups = []
|
||||||
|
|
||||||
|
|
||||||
def _tp_cache(func):
|
def _tp_cache(func=None, /, *, typed=False):
|
||||||
"""Internal wrapper caching __getitem__ of generic types with a fallback to
|
"""Internal wrapper caching __getitem__ of generic types with a fallback to
|
||||||
original function for non-hashable arguments.
|
original function for non-hashable arguments.
|
||||||
"""
|
"""
|
||||||
cached = functools.lru_cache()(func)
|
def decorator(func):
|
||||||
_cleanups.append(cached.cache_clear)
|
cached = functools.lru_cache(typed=typed)(func)
|
||||||
|
_cleanups.append(cached.cache_clear)
|
||||||
|
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
def inner(*args, **kwds):
|
def inner(*args, **kwds):
|
||||||
try:
|
try:
|
||||||
return cached(*args, **kwds)
|
return cached(*args, **kwds)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
pass # All real errors (not unhashable args) are raised below.
|
pass # All real errors (not unhashable args) are raised below.
|
||||||
return func(*args, **kwds)
|
return func(*args, **kwds)
|
||||||
return inner
|
return inner
|
||||||
|
|
||||||
|
if func is not None:
|
||||||
|
return decorator(func)
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
|
def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
|
||||||
"""Evaluate all forward references in the given type t.
|
"""Evaluate all forward references in the given type t.
|
||||||
|
@ -319,6 +340,13 @@ class _SpecialForm(_Final, _root=True):
|
||||||
def __getitem__(self, parameters):
|
def __getitem__(self, parameters):
|
||||||
return self._getitem(self, parameters)
|
return self._getitem(self, parameters)
|
||||||
|
|
||||||
|
|
||||||
|
class _LiteralSpecialForm(_SpecialForm, _root=True):
|
||||||
|
@_tp_cache(typed=True)
|
||||||
|
def __getitem__(self, parameters):
|
||||||
|
return self._getitem(self, parameters)
|
||||||
|
|
||||||
|
|
||||||
@_SpecialForm
|
@_SpecialForm
|
||||||
def Any(self, parameters):
|
def Any(self, parameters):
|
||||||
"""Special type indicating an unconstrained type.
|
"""Special type indicating an unconstrained type.
|
||||||
|
@ -436,7 +464,7 @@ def Optional(self, parameters):
|
||||||
arg = _type_check(parameters, f"{self} requires a single type.")
|
arg = _type_check(parameters, f"{self} requires a single type.")
|
||||||
return Union[arg, type(None)]
|
return Union[arg, type(None)]
|
||||||
|
|
||||||
@_SpecialForm
|
@_LiteralSpecialForm
|
||||||
def Literal(self, parameters):
|
def Literal(self, parameters):
|
||||||
"""Special typing form to define literal types (a.k.a. value types).
|
"""Special typing form to define literal types (a.k.a. value types).
|
||||||
|
|
||||||
|
@ -460,7 +488,17 @@ def Literal(self, parameters):
|
||||||
"""
|
"""
|
||||||
# There is no '_type_check' call because arguments to Literal[...] are
|
# There is no '_type_check' call because arguments to Literal[...] are
|
||||||
# values, not types.
|
# values, not types.
|
||||||
return _GenericAlias(self, parameters)
|
if not isinstance(parameters, tuple):
|
||||||
|
parameters = (parameters,)
|
||||||
|
|
||||||
|
parameters = _flatten_literal_params(parameters)
|
||||||
|
|
||||||
|
try:
|
||||||
|
parameters = tuple(p for p, _ in _deduplicate(list(_value_and_type_iter(parameters))))
|
||||||
|
except TypeError: # unhashable parameters
|
||||||
|
pass
|
||||||
|
|
||||||
|
return _LiteralGenericAlias(self, parameters)
|
||||||
|
|
||||||
|
|
||||||
@_SpecialForm
|
@_SpecialForm
|
||||||
|
@ -930,6 +968,21 @@ class _UnionGenericAlias(_GenericAlias, _root=True):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _value_and_type_iter(parameters):
|
||||||
|
return ((p, type(p)) for p in parameters)
|
||||||
|
|
||||||
|
|
||||||
|
class _LiteralGenericAlias(_GenericAlias, _root=True):
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
if not isinstance(other, _LiteralGenericAlias):
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
return set(_value_and_type_iter(self.__args__)) == set(_value_and_type_iter(other.__args__))
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(tuple(_value_and_type_iter(self.__args__)))
|
||||||
|
|
||||||
|
|
||||||
class Generic:
|
class Generic:
|
||||||
"""Abstract base class for generic types.
|
"""Abstract base class for generic types.
|
||||||
|
|
|
@ -861,6 +861,7 @@ Jan Kanis
|
||||||
Rafe Kaplan
|
Rafe Kaplan
|
||||||
Jacob Kaplan-Moss
|
Jacob Kaplan-Moss
|
||||||
Allison Kaptur
|
Allison Kaptur
|
||||||
|
Yurii Karabas
|
||||||
Janne Karila
|
Janne Karila
|
||||||
Per Øyvind Karlsen
|
Per Øyvind Karlsen
|
||||||
Anton Kasyanov
|
Anton Kasyanov
|
||||||
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
Fix various issues with ``typing.Literal`` parameter handling (flatten,
|
||||||
|
deduplicate, use type to cache key). Patch provided by Yurii Karabas.
|
Loading…
Reference in New Issue