From f03d318ca42578e45405717aedd4ac26ea52aaed Mon Sep 17 00:00:00 2001 From: Yurii Karabas <1998uriyyo@gmail.com> Date: Tue, 17 Nov 2020 04:23:19 +0200 Subject: [PATCH] bpo-42345: Fix three issues with typing.Literal parameters (GH-23294) Literal equality no longer depends on the order of arguments. Fix issue related to `typing.Literal` caching by adding `typed` parameter to `typing._tp_cache` function. Add deduplication of `typing.Literal` arguments. --- Lib/test/test_typing.py | 25 +++++ Lib/typing.py | 99 ++++++++++++++----- Misc/ACKS | 1 + .../2020-11-15-15-23-34.bpo-42345.hiIR7x.rst | 2 + 4 files changed, 104 insertions(+), 23 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2020-11-15-15-23-34.bpo-42345.hiIR7x.rst diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 2ab8be49b28..7deba0d71b7 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -528,6 +528,7 @@ class LiteralTests(BaseTestCase): self.assertEqual(repr(Literal[int]), "typing.Literal[int]") self.assertEqual(repr(Literal), "typing.Literal") self.assertEqual(repr(Literal[None]), "typing.Literal[None]") + self.assertEqual(repr(Literal[1, 2, 3, 3]), "typing.Literal[1, 2, 3]") def test_cannot_init(self): with self.assertRaises(TypeError): @@ -559,6 +560,30 @@ class LiteralTests(BaseTestCase): with self.assertRaises(TypeError): Literal[1][1] + def test_equal(self): + self.assertNotEqual(Literal[0], Literal[False]) + self.assertNotEqual(Literal[True], Literal[1]) + self.assertNotEqual(Literal[1], Literal[2]) + self.assertNotEqual(Literal[1, True], Literal[1]) + self.assertEqual(Literal[1], Literal[1]) + self.assertEqual(Literal[1, 2], Literal[2, 1]) + self.assertEqual(Literal[1, 2, 3], Literal[1, 2, 3, 3]) + + def test_args(self): + self.assertEqual(Literal[1, 2, 3].__args__, (1, 2, 3)) + self.assertEqual(Literal[1, 2, 3, 3].__args__, (1, 2, 3)) + self.assertEqual(Literal[1, Literal[2], Literal[3, 4]].__args__, (1, 2, 3, 4)) + # Mutable arguments will not be deduplicated + self.assertEqual(Literal[[], []].__args__, ([], [])) + + def test_flatten(self): + l1 = Literal[Literal[1], Literal[2], Literal[3]] + l2 = Literal[Literal[1, 2], 3] + l3 = Literal[Literal[1, 2, 3]] + for l in l1, l2, l3: + self.assertEqual(l, Literal[1, 2, 3]) + self.assertEqual(l.__args__, (1, 2, 3)) + XK = TypeVar('XK', str, bytes) XV = TypeVar('XV') diff --git a/Lib/typing.py b/Lib/typing.py index 3fa97a4a15f..d310b3dd582 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -202,6 +202,20 @@ def _check_generic(cls, parameters, elen): f" actual {alen}, expected {elen}") +def _deduplicate(params): + # Weed out strict duplicates, preserving the first of each occurrence. + all_params = set(params) + if len(all_params) < len(params): + new_params = [] + for t in params: + if t in all_params: + new_params.append(t) + all_params.remove(t) + params = new_params + assert not all_params, all_params + return params + + def _remove_dups_flatten(parameters): """An internal helper for Union creation and substitution: flatten Unions among parameters, then remove duplicates. @@ -215,38 +229,45 @@ def _remove_dups_flatten(parameters): params.extend(p[1:]) else: params.append(p) - # Weed out strict duplicates, preserving the first of each occurrence. - all_params = set(params) - if len(all_params) < len(params): - new_params = [] - for t in params: - if t in all_params: - new_params.append(t) - all_params.remove(t) - params = new_params - assert not all_params, all_params + + return tuple(_deduplicate(params)) + + +def _flatten_literal_params(parameters): + """An internal helper for Literal creation: flatten Literals among parameters""" + params = [] + for p in parameters: + if isinstance(p, _LiteralGenericAlias): + params.extend(p.__args__) + else: + params.append(p) return tuple(params) _cleanups = [] -def _tp_cache(func): +def _tp_cache(func=None, /, *, typed=False): """Internal wrapper caching __getitem__ of generic types with a fallback to original function for non-hashable arguments. """ - cached = functools.lru_cache()(func) - _cleanups.append(cached.cache_clear) + def decorator(func): + cached = functools.lru_cache(typed=typed)(func) + _cleanups.append(cached.cache_clear) - @functools.wraps(func) - def inner(*args, **kwds): - try: - return cached(*args, **kwds) - except TypeError: - pass # All real errors (not unhashable args) are raised below. - return func(*args, **kwds) - return inner + @functools.wraps(func) + def inner(*args, **kwds): + try: + return cached(*args, **kwds) + except TypeError: + pass # All real errors (not unhashable args) are raised below. + return func(*args, **kwds) + return inner + if func is not None: + return decorator(func) + + return decorator def _eval_type(t, globalns, localns, recursive_guard=frozenset()): """Evaluate all forward references in the given type t. @@ -319,6 +340,13 @@ class _SpecialForm(_Final, _root=True): def __getitem__(self, parameters): return self._getitem(self, parameters) + +class _LiteralSpecialForm(_SpecialForm, _root=True): + @_tp_cache(typed=True) + def __getitem__(self, parameters): + return self._getitem(self, parameters) + + @_SpecialForm def Any(self, parameters): """Special type indicating an unconstrained type. @@ -436,7 +464,7 @@ def Optional(self, parameters): arg = _type_check(parameters, f"{self} requires a single type.") return Union[arg, type(None)] -@_SpecialForm +@_LiteralSpecialForm def Literal(self, parameters): """Special typing form to define literal types (a.k.a. value types). @@ -460,7 +488,17 @@ def Literal(self, parameters): """ # There is no '_type_check' call because arguments to Literal[...] are # values, not types. - return _GenericAlias(self, parameters) + if not isinstance(parameters, tuple): + parameters = (parameters,) + + parameters = _flatten_literal_params(parameters) + + try: + parameters = tuple(p for p, _ in _deduplicate(list(_value_and_type_iter(parameters)))) + except TypeError: # unhashable parameters + pass + + return _LiteralGenericAlias(self, parameters) @_SpecialForm @@ -930,6 +968,21 @@ class _UnionGenericAlias(_GenericAlias, _root=True): return True +def _value_and_type_iter(parameters): + return ((p, type(p)) for p in parameters) + + +class _LiteralGenericAlias(_GenericAlias, _root=True): + + def __eq__(self, other): + if not isinstance(other, _LiteralGenericAlias): + return NotImplemented + + return set(_value_and_type_iter(self.__args__)) == set(_value_and_type_iter(other.__args__)) + + def __hash__(self): + return hash(tuple(_value_and_type_iter(self.__args__))) + class Generic: """Abstract base class for generic types. diff --git a/Misc/ACKS b/Misc/ACKS index 35a87ae6b96..1d106144d46 100644 --- a/Misc/ACKS +++ b/Misc/ACKS @@ -861,6 +861,7 @@ Jan Kanis Rafe Kaplan Jacob Kaplan-Moss Allison Kaptur +Yurii Karabas Janne Karila Per Øyvind Karlsen Anton Kasyanov diff --git a/Misc/NEWS.d/next/Library/2020-11-15-15-23-34.bpo-42345.hiIR7x.rst b/Misc/NEWS.d/next/Library/2020-11-15-15-23-34.bpo-42345.hiIR7x.rst new file mode 100644 index 00000000000..6339182c3ae --- /dev/null +++ b/Misc/NEWS.d/next/Library/2020-11-15-15-23-34.bpo-42345.hiIR7x.rst @@ -0,0 +1,2 @@ +Fix various issues with ``typing.Literal`` parameter handling (flatten, +deduplicate, use type to cache key). Patch provided by Yurii Karabas.