diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index e3c06561577..382e7dbffdd 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -2792,6 +2792,49 @@ class TestSingleDispatch(unittest.TestCase): self.assertEqual(f(1), "types.UnionType") self.assertEqual(f(1.0), "types.UnionType") + def test_union_conflict(self): + @functools.singledispatch + def f(arg): + return "default" + + @f.register + def _(arg: typing.Union[str, bytes]): + return "typing.Union" + + @f.register + def _(arg: int | str): + return "types.UnionType" + + self.assertEqual(f([]), "default") + self.assertEqual(f(""), "types.UnionType") # last one wins + self.assertEqual(f(b""), "typing.Union") + self.assertEqual(f(1), "types.UnionType") + + def test_union_None(self): + @functools.singledispatch + def typing_union(arg): + return "default" + + @typing_union.register + def _(arg: typing.Union[str, None]): + return "typing.Union" + + self.assertEqual(typing_union(1), "default") + self.assertEqual(typing_union(""), "typing.Union") + self.assertEqual(typing_union(None), "typing.Union") + + @functools.singledispatch + def types_union(arg): + return "default" + + @types_union.register + def _(arg: int | None): + return "types.UnionType" + + self.assertEqual(types_union(""), "default") + self.assertEqual(types_union(1), "types.UnionType") + self.assertEqual(types_union(None), "types.UnionType") + def test_register_genericalias(self): @functools.singledispatch def f(arg):