bpo-42345: Fix three issues with typing.Literal parameters (GH-23294)
Literal equality no longer depends on the order of arguments. Fix issue related to `typing.Literal` caching by adding `typed` parameter to `typing._tp_cache` function. Add deduplication of `typing.Literal` arguments.
This commit is contained in:
parent
b0aba1fcdc
commit
f03d318ca4
|
@ -528,6 +528,7 @@ class LiteralTests(BaseTestCase):
|
|||
self.assertEqual(repr(Literal[int]), "typing.Literal[int]")
|
||||
self.assertEqual(repr(Literal), "typing.Literal")
|
||||
self.assertEqual(repr(Literal[None]), "typing.Literal[None]")
|
||||
self.assertEqual(repr(Literal[1, 2, 3, 3]), "typing.Literal[1, 2, 3]")
|
||||
|
||||
def test_cannot_init(self):
|
||||
with self.assertRaises(TypeError):
|
||||
|
@ -559,6 +560,30 @@ class LiteralTests(BaseTestCase):
|
|||
with self.assertRaises(TypeError):
|
||||
Literal[1][1]
|
||||
|
||||
def test_equal(self):
|
||||
self.assertNotEqual(Literal[0], Literal[False])
|
||||
self.assertNotEqual(Literal[True], Literal[1])
|
||||
self.assertNotEqual(Literal[1], Literal[2])
|
||||
self.assertNotEqual(Literal[1, True], Literal[1])
|
||||
self.assertEqual(Literal[1], Literal[1])
|
||||
self.assertEqual(Literal[1, 2], Literal[2, 1])
|
||||
self.assertEqual(Literal[1, 2, 3], Literal[1, 2, 3, 3])
|
||||
|
||||
def test_args(self):
|
||||
self.assertEqual(Literal[1, 2, 3].__args__, (1, 2, 3))
|
||||
self.assertEqual(Literal[1, 2, 3, 3].__args__, (1, 2, 3))
|
||||
self.assertEqual(Literal[1, Literal[2], Literal[3, 4]].__args__, (1, 2, 3, 4))
|
||||
# Mutable arguments will not be deduplicated
|
||||
self.assertEqual(Literal[[], []].__args__, ([], []))
|
||||
|
||||
def test_flatten(self):
|
||||
l1 = Literal[Literal[1], Literal[2], Literal[3]]
|
||||
l2 = Literal[Literal[1, 2], 3]
|
||||
l3 = Literal[Literal[1, 2, 3]]
|
||||
for l in l1, l2, l3:
|
||||
self.assertEqual(l, Literal[1, 2, 3])
|
||||
self.assertEqual(l.__args__, (1, 2, 3))
|
||||
|
||||
|
||||
XK = TypeVar('XK', str, bytes)
|
||||
XV = TypeVar('XV')
|
||||
|
|
|
@ -202,6 +202,20 @@ def _check_generic(cls, parameters, elen):
|
|||
f" actual {alen}, expected {elen}")
|
||||
|
||||
|
||||
def _deduplicate(params):
|
||||
# Weed out strict duplicates, preserving the first of each occurrence.
|
||||
all_params = set(params)
|
||||
if len(all_params) < len(params):
|
||||
new_params = []
|
||||
for t in params:
|
||||
if t in all_params:
|
||||
new_params.append(t)
|
||||
all_params.remove(t)
|
||||
params = new_params
|
||||
assert not all_params, all_params
|
||||
return params
|
||||
|
||||
|
||||
def _remove_dups_flatten(parameters):
|
||||
"""An internal helper for Union creation and substitution: flatten Unions
|
||||
among parameters, then remove duplicates.
|
||||
|
@ -215,38 +229,45 @@ def _remove_dups_flatten(parameters):
|
|||
params.extend(p[1:])
|
||||
else:
|
||||
params.append(p)
|
||||
# Weed out strict duplicates, preserving the first of each occurrence.
|
||||
all_params = set(params)
|
||||
if len(all_params) < len(params):
|
||||
new_params = []
|
||||
for t in params:
|
||||
if t in all_params:
|
||||
new_params.append(t)
|
||||
all_params.remove(t)
|
||||
params = new_params
|
||||
assert not all_params, all_params
|
||||
|
||||
return tuple(_deduplicate(params))
|
||||
|
||||
|
||||
def _flatten_literal_params(parameters):
|
||||
"""An internal helper for Literal creation: flatten Literals among parameters"""
|
||||
params = []
|
||||
for p in parameters:
|
||||
if isinstance(p, _LiteralGenericAlias):
|
||||
params.extend(p.__args__)
|
||||
else:
|
||||
params.append(p)
|
||||
return tuple(params)
|
||||
|
||||
|
||||
_cleanups = []
|
||||
|
||||
|
||||
def _tp_cache(func):
|
||||
def _tp_cache(func=None, /, *, typed=False):
|
||||
"""Internal wrapper caching __getitem__ of generic types with a fallback to
|
||||
original function for non-hashable arguments.
|
||||
"""
|
||||
cached = functools.lru_cache()(func)
|
||||
_cleanups.append(cached.cache_clear)
|
||||
def decorator(func):
|
||||
cached = functools.lru_cache(typed=typed)(func)
|
||||
_cleanups.append(cached.cache_clear)
|
||||
|
||||
@functools.wraps(func)
|
||||
def inner(*args, **kwds):
|
||||
try:
|
||||
return cached(*args, **kwds)
|
||||
except TypeError:
|
||||
pass # All real errors (not unhashable args) are raised below.
|
||||
return func(*args, **kwds)
|
||||
return inner
|
||||
@functools.wraps(func)
|
||||
def inner(*args, **kwds):
|
||||
try:
|
||||
return cached(*args, **kwds)
|
||||
except TypeError:
|
||||
pass # All real errors (not unhashable args) are raised below.
|
||||
return func(*args, **kwds)
|
||||
return inner
|
||||
|
||||
if func is not None:
|
||||
return decorator(func)
|
||||
|
||||
return decorator
|
||||
|
||||
def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
|
||||
"""Evaluate all forward references in the given type t.
|
||||
|
@ -319,6 +340,13 @@ class _SpecialForm(_Final, _root=True):
|
|||
def __getitem__(self, parameters):
|
||||
return self._getitem(self, parameters)
|
||||
|
||||
|
||||
class _LiteralSpecialForm(_SpecialForm, _root=True):
|
||||
@_tp_cache(typed=True)
|
||||
def __getitem__(self, parameters):
|
||||
return self._getitem(self, parameters)
|
||||
|
||||
|
||||
@_SpecialForm
|
||||
def Any(self, parameters):
|
||||
"""Special type indicating an unconstrained type.
|
||||
|
@ -436,7 +464,7 @@ def Optional(self, parameters):
|
|||
arg = _type_check(parameters, f"{self} requires a single type.")
|
||||
return Union[arg, type(None)]
|
||||
|
||||
@_SpecialForm
|
||||
@_LiteralSpecialForm
|
||||
def Literal(self, parameters):
|
||||
"""Special typing form to define literal types (a.k.a. value types).
|
||||
|
||||
|
@ -460,7 +488,17 @@ def Literal(self, parameters):
|
|||
"""
|
||||
# There is no '_type_check' call because arguments to Literal[...] are
|
||||
# values, not types.
|
||||
return _GenericAlias(self, parameters)
|
||||
if not isinstance(parameters, tuple):
|
||||
parameters = (parameters,)
|
||||
|
||||
parameters = _flatten_literal_params(parameters)
|
||||
|
||||
try:
|
||||
parameters = tuple(p for p, _ in _deduplicate(list(_value_and_type_iter(parameters))))
|
||||
except TypeError: # unhashable parameters
|
||||
pass
|
||||
|
||||
return _LiteralGenericAlias(self, parameters)
|
||||
|
||||
|
||||
@_SpecialForm
|
||||
|
@ -930,6 +968,21 @@ class _UnionGenericAlias(_GenericAlias, _root=True):
|
|||
return True
|
||||
|
||||
|
||||
def _value_and_type_iter(parameters):
|
||||
return ((p, type(p)) for p in parameters)
|
||||
|
||||
|
||||
class _LiteralGenericAlias(_GenericAlias, _root=True):
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, _LiteralGenericAlias):
|
||||
return NotImplemented
|
||||
|
||||
return set(_value_and_type_iter(self.__args__)) == set(_value_and_type_iter(other.__args__))
|
||||
|
||||
def __hash__(self):
|
||||
return hash(tuple(_value_and_type_iter(self.__args__)))
|
||||
|
||||
|
||||
class Generic:
|
||||
"""Abstract base class for generic types.
|
||||
|
|
|
@ -861,6 +861,7 @@ Jan Kanis
|
|||
Rafe Kaplan
|
||||
Jacob Kaplan-Moss
|
||||
Allison Kaptur
|
||||
Yurii Karabas
|
||||
Janne Karila
|
||||
Per Øyvind Karlsen
|
||||
Anton Kasyanov
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
Fix various issues with ``typing.Literal`` parameter handling (flatten,
|
||||
deduplicate, use type to cache key). Patch provided by Yurii Karabas.
|
Loading…
Reference in New Issue