diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py index d940f1aaaf3..2d0e33f03c3 100644 --- a/Lib/test/test_types.py +++ b/Lib/test/test_types.py @@ -18,6 +18,11 @@ class Example: class Forward: ... +def clear_typing_caches(): + for f in typing._cleanups: + f() + + class TypesTests(unittest.TestCase): def test_truth_values(self): @@ -710,11 +715,34 @@ class TypesTests(unittest.TestCase): self.assertIs((TV | int)[int], int) def test_union_args(self): - self.assertEqual((int | str).__args__, (int, str)) - self.assertEqual(((int | str) | list).__args__, (int, str, list)) - self.assertEqual((int | (str | list)).__args__, (int, str, list)) - self.assertEqual((int | None).__args__, (int, type(None))) - self.assertEqual((int | type(None)).__args__, (int, type(None))) + def check(arg, expected): + clear_typing_caches() + self.assertEqual(arg.__args__, expected) + + check(int | str, (int, str)) + check((int | str) | list, (int, str, list)) + check(int | (str | list), (int, str, list)) + check((int | str) | int, (int, str)) + check(int | (str | int), (int, str)) + check((int | str) | (str | int), (int, str)) + check(typing.Union[int, str] | list, (int, str, list)) + check(int | typing.Union[str, list], (int, str, list)) + check((int | str) | (list | int), (int, str, list)) + check((int | str) | typing.Union[list, int], (int, str, list)) + check(typing.Union[int, str] | (list | int), (int, str, list)) + check((str | int) | (int | list), (str, int, list)) + check((str | int) | typing.Union[int, list], (str, int, list)) + check(typing.Union[str, int] | (int | list), (str, int, list)) + check(int | type(None), (int, type(None))) + check(type(None) | int, (type(None), int)) + + args = (int, list[int], typing.List[int], + typing.Tuple[int, int], typing.Callable[[int], int], + typing.Hashable, typing.TypeVar('T')) + for x in args: + with self.subTest(x): + check(x | None, (x, type(None))) + check(None | x, (type(None), x)) def test_union_parameter_chaining(self): T = typing.TypeVar("T") diff --git a/Lib/typing.py b/Lib/typing.py index ca05fb54bf4..2f2286813a4 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -988,8 +988,8 @@ class _GenericAlias(_BaseGenericAlias, _root=True): def __or__(self, right): return Union[self, right] - def __ror__(self, right): - return Union[self, right] + def __ror__(self, left): + return Union[left, self] @_tp_cache def __getitem__(self, params): @@ -1099,8 +1099,8 @@ class _SpecialGenericAlias(_BaseGenericAlias, _root=True): def __or__(self, right): return Union[self, right] - def __ror__(self, right): - return Union[self, right] + def __ror__(self, left): + return Union[left, self] class _CallableGenericAlias(_GenericAlias, _root=True): def __repr__(self): diff --git a/Objects/unionobject.c b/Objects/unionobject.c index dad26c32b29..b3a65068626 100644 --- a/Objects/unionobject.c +++ b/Objects/unionobject.c @@ -260,8 +260,8 @@ dedup_and_flatten_args(PyObject* args) for (Py_ssize_t i = 0; i < arg_length; i++) { int is_duplicate = 0; PyObject* i_element = PyTuple_GET_ITEM(args, i); - for (Py_ssize_t j = i + 1; j < arg_length; j++) { - PyObject* j_element = PyTuple_GET_ITEM(args, j); + for (Py_ssize_t j = 0; j < added_items; j++) { + PyObject* j_element = PyTuple_GET_ITEM(new_args, j); int is_ga = PyObject_TypeCheck(i_element, &Py_GenericAliasType) && PyObject_TypeCheck(j_element, &Py_GenericAliasType); // RichCompare to also deduplicate GenericAlias types (slower)