bpo-42345: Fix three issues with typing.Literal parameters (GH-23294)

Literal equality no longer depends on the order of arguments.

Fix issue related to `typing.Literal` caching by adding `typed` parameter to `typing._tp_cache` function.

Add deduplication of `typing.Literal` arguments.
This commit is contained in:
Yurii Karabas 2020-11-17 04:23:19 +02:00 committed by GitHub
parent b0aba1fcdc
commit f03d318ca4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 104 additions and 23 deletions

View File

@ -528,6 +528,7 @@ class LiteralTests(BaseTestCase):
self.assertEqual(repr(Literal[int]), "typing.Literal[int]") self.assertEqual(repr(Literal[int]), "typing.Literal[int]")
self.assertEqual(repr(Literal), "typing.Literal") self.assertEqual(repr(Literal), "typing.Literal")
self.assertEqual(repr(Literal[None]), "typing.Literal[None]") self.assertEqual(repr(Literal[None]), "typing.Literal[None]")
self.assertEqual(repr(Literal[1, 2, 3, 3]), "typing.Literal[1, 2, 3]")
def test_cannot_init(self): def test_cannot_init(self):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
@ -559,6 +560,30 @@ class LiteralTests(BaseTestCase):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
Literal[1][1] Literal[1][1]
def test_equal(self):
self.assertNotEqual(Literal[0], Literal[False])
self.assertNotEqual(Literal[True], Literal[1])
self.assertNotEqual(Literal[1], Literal[2])
self.assertNotEqual(Literal[1, True], Literal[1])
self.assertEqual(Literal[1], Literal[1])
self.assertEqual(Literal[1, 2], Literal[2, 1])
self.assertEqual(Literal[1, 2, 3], Literal[1, 2, 3, 3])
def test_args(self):
self.assertEqual(Literal[1, 2, 3].__args__, (1, 2, 3))
self.assertEqual(Literal[1, 2, 3, 3].__args__, (1, 2, 3))
self.assertEqual(Literal[1, Literal[2], Literal[3, 4]].__args__, (1, 2, 3, 4))
# Mutable arguments will not be deduplicated
self.assertEqual(Literal[[], []].__args__, ([], []))
def test_flatten(self):
l1 = Literal[Literal[1], Literal[2], Literal[3]]
l2 = Literal[Literal[1, 2], 3]
l3 = Literal[Literal[1, 2, 3]]
for l in l1, l2, l3:
self.assertEqual(l, Literal[1, 2, 3])
self.assertEqual(l.__args__, (1, 2, 3))
XK = TypeVar('XK', str, bytes) XK = TypeVar('XK', str, bytes)
XV = TypeVar('XV') XV = TypeVar('XV')

View File

@ -202,6 +202,20 @@ def _check_generic(cls, parameters, elen):
f" actual {alen}, expected {elen}") f" actual {alen}, expected {elen}")
def _deduplicate(params):
# Weed out strict duplicates, preserving the first of each occurrence.
all_params = set(params)
if len(all_params) < len(params):
new_params = []
for t in params:
if t in all_params:
new_params.append(t)
all_params.remove(t)
params = new_params
assert not all_params, all_params
return params
def _remove_dups_flatten(parameters): def _remove_dups_flatten(parameters):
"""An internal helper for Union creation and substitution: flatten Unions """An internal helper for Union creation and substitution: flatten Unions
among parameters, then remove duplicates. among parameters, then remove duplicates.
@ -215,38 +229,45 @@ def _remove_dups_flatten(parameters):
params.extend(p[1:]) params.extend(p[1:])
else: else:
params.append(p) params.append(p)
# Weed out strict duplicates, preserving the first of each occurrence.
all_params = set(params) return tuple(_deduplicate(params))
if len(all_params) < len(params):
new_params = []
for t in params: def _flatten_literal_params(parameters):
if t in all_params: """An internal helper for Literal creation: flatten Literals among parameters"""
new_params.append(t) params = []
all_params.remove(t) for p in parameters:
params = new_params if isinstance(p, _LiteralGenericAlias):
assert not all_params, all_params params.extend(p.__args__)
else:
params.append(p)
return tuple(params) return tuple(params)
_cleanups = [] _cleanups = []
def _tp_cache(func): def _tp_cache(func=None, /, *, typed=False):
"""Internal wrapper caching __getitem__ of generic types with a fallback to """Internal wrapper caching __getitem__ of generic types with a fallback to
original function for non-hashable arguments. original function for non-hashable arguments.
""" """
cached = functools.lru_cache()(func) def decorator(func):
_cleanups.append(cached.cache_clear) cached = functools.lru_cache(typed=typed)(func)
_cleanups.append(cached.cache_clear)
@functools.wraps(func) @functools.wraps(func)
def inner(*args, **kwds): def inner(*args, **kwds):
try: try:
return cached(*args, **kwds) return cached(*args, **kwds)
except TypeError: except TypeError:
pass # All real errors (not unhashable args) are raised below. pass # All real errors (not unhashable args) are raised below.
return func(*args, **kwds) return func(*args, **kwds)
return inner return inner
if func is not None:
return decorator(func)
return decorator
def _eval_type(t, globalns, localns, recursive_guard=frozenset()): def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
"""Evaluate all forward references in the given type t. """Evaluate all forward references in the given type t.
@ -319,6 +340,13 @@ class _SpecialForm(_Final, _root=True):
def __getitem__(self, parameters): def __getitem__(self, parameters):
return self._getitem(self, parameters) return self._getitem(self, parameters)
class _LiteralSpecialForm(_SpecialForm, _root=True):
@_tp_cache(typed=True)
def __getitem__(self, parameters):
return self._getitem(self, parameters)
@_SpecialForm @_SpecialForm
def Any(self, parameters): def Any(self, parameters):
"""Special type indicating an unconstrained type. """Special type indicating an unconstrained type.
@ -436,7 +464,7 @@ def Optional(self, parameters):
arg = _type_check(parameters, f"{self} requires a single type.") arg = _type_check(parameters, f"{self} requires a single type.")
return Union[arg, type(None)] return Union[arg, type(None)]
@_SpecialForm @_LiteralSpecialForm
def Literal(self, parameters): def Literal(self, parameters):
"""Special typing form to define literal types (a.k.a. value types). """Special typing form to define literal types (a.k.a. value types).
@ -460,7 +488,17 @@ def Literal(self, parameters):
""" """
# There is no '_type_check' call because arguments to Literal[...] are # There is no '_type_check' call because arguments to Literal[...] are
# values, not types. # values, not types.
return _GenericAlias(self, parameters) if not isinstance(parameters, tuple):
parameters = (parameters,)
parameters = _flatten_literal_params(parameters)
try:
parameters = tuple(p for p, _ in _deduplicate(list(_value_and_type_iter(parameters))))
except TypeError: # unhashable parameters
pass
return _LiteralGenericAlias(self, parameters)
@_SpecialForm @_SpecialForm
@ -930,6 +968,21 @@ class _UnionGenericAlias(_GenericAlias, _root=True):
return True return True
def _value_and_type_iter(parameters):
return ((p, type(p)) for p in parameters)
class _LiteralGenericAlias(_GenericAlias, _root=True):
def __eq__(self, other):
if not isinstance(other, _LiteralGenericAlias):
return NotImplemented
return set(_value_and_type_iter(self.__args__)) == set(_value_and_type_iter(other.__args__))
def __hash__(self):
return hash(tuple(_value_and_type_iter(self.__args__)))
class Generic: class Generic:
"""Abstract base class for generic types. """Abstract base class for generic types.

View File

@ -861,6 +861,7 @@ Jan Kanis
Rafe Kaplan Rafe Kaplan
Jacob Kaplan-Moss Jacob Kaplan-Moss
Allison Kaptur Allison Kaptur
Yurii Karabas
Janne Karila Janne Karila
Per Øyvind Karlsen Per Øyvind Karlsen
Anton Kasyanov Anton Kasyanov

View File

@ -0,0 +1,2 @@
Fix various issues with ``typing.Literal`` parameter handling (flatten,
deduplicate, use type to cache key). Patch provided by Yurii Karabas.