diff --git a/Include/internal/pycore_unionobject.h b/Include/internal/pycore_unionobject.h index fa8ba6ed944..4d82b6fbeae 100644 --- a/Include/internal/pycore_unionobject.h +++ b/Include/internal/pycore_unionobject.h @@ -10,6 +10,7 @@ extern "C" { PyAPI_FUNC(PyObject *) _Py_Union(PyObject *args); PyAPI_DATA(PyTypeObject) _Py_UnionType; +PyAPI_FUNC(PyObject *) _Py_union_type_or(PyObject* self, PyObject* param); #ifdef __cplusplus } diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py index 75c5eee42dc..3058a02d6ee 100644 --- a/Lib/test/test_types.py +++ b/Lib/test/test_types.py @@ -713,6 +713,28 @@ class TypesTests(unittest.TestCase): assert repr(int | None) == "int | None" assert repr(int | typing.GenericAlias(list, int)) == "int | list[int]" + def test_or_type_operator_with_genericalias(self): + a = list[int] + b = list[str] + c = dict[float, str] + # equivalence with typing.Union + self.assertEqual(a | b | c, typing.Union[a, b, c]) + # de-duplicate + self.assertEqual(a | c | b | b | a | c, a | b | c) + # order shouldn't matter + self.assertEqual(a | b, b | a) + self.assertEqual(repr(a | b | c), + "list[int] | list[str] | dict[float, str]") + + class BadType(type): + def __eq__(self, other): + return 1 / 0 + + bt = BadType('bt', (), {}) + # Comparison should fail and errors should propagate out for bad types. + with self.assertRaises(ZeroDivisionError): + list[int] | list[bt] + def test_ellipsis_type(self): self.assertIsInstance(Ellipsis, types.EllipsisType) diff --git a/Misc/NEWS.d/next/Core and Builtins/2020-11-01-23-34-56.bpo-42233.zOSzja.rst b/Misc/NEWS.d/next/Core and Builtins/2020-11-01-23-34-56.bpo-42233.zOSzja.rst new file mode 100644 index 00000000000..499bb324fb9 --- /dev/null +++ b/Misc/NEWS.d/next/Core and Builtins/2020-11-01-23-34-56.bpo-42233.zOSzja.rst @@ -0,0 +1,5 @@ +Allow ``GenericAlias`` objects to use :ref:`union type expressions `. +This allows expressions like ``list[int] | dict[float, str]`` where previously a +``TypeError`` would have been thrown. This also fixes union type expressions +not de-duplicating ``GenericAlias`` objects. (Contributed by Ken Jin in +:issue:`42233`.) diff --git a/Objects/genericaliasobject.c b/Objects/genericaliasobject.c index 6508c69cbf7..28ea487a44f 100644 --- a/Objects/genericaliasobject.c +++ b/Objects/genericaliasobject.c @@ -2,6 +2,7 @@ #include "Python.h" #include "pycore_object.h" +#include "pycore_unionobject.h" // _Py_union_as_number #include "structmember.h" // PyMemberDef typedef struct { @@ -573,6 +574,10 @@ ga_new(PyTypeObject *type, PyObject *args, PyObject *kwds) return Py_GenericAlias(origin, arguments); } +static PyNumberMethods ga_as_number = { + .nb_or = (binaryfunc)_Py_union_type_or, // Add __or__ function +}; + // TODO: // - argument clinic? // - __doc__? @@ -586,6 +591,7 @@ PyTypeObject Py_GenericAliasType = { .tp_basicsize = sizeof(gaobject), .tp_dealloc = ga_dealloc, .tp_repr = ga_repr, + .tp_as_number = &ga_as_number, // allow X | Y of GenericAlias objs .tp_as_mapping = &ga_as_mapping, .tp_hash = ga_hash, .tp_call = ga_call, diff --git a/Objects/typeobject.c b/Objects/typeobject.c index 3822b8cf813..55bf9b3f389 100644 --- a/Objects/typeobject.c +++ b/Objects/typeobject.c @@ -6,7 +6,7 @@ #include "pycore_object.h" #include "pycore_pyerrors.h" #include "pycore_pystate.h" // _PyThreadState_GET() -#include "pycore_unionobject.h" // _Py_Union() +#include "pycore_unionobject.h" // _Py_Union(), _Py_union_type_or #include "frameobject.h" #include "structmember.h" // PyMemberDef @@ -3789,19 +3789,9 @@ type_is_gc(PyTypeObject *type) return type->tp_flags & Py_TPFLAGS_HEAPTYPE; } -static PyObject * -type_or(PyTypeObject* self, PyObject* param) { - PyObject *tuple = PyTuple_Pack(2, self, param); - if (tuple == NULL) { - return NULL; - } - PyObject *new_union = _Py_Union(tuple); - Py_DECREF(tuple); - return new_union; -} static PyNumberMethods type_as_number = { - .nb_or = (binaryfunc)type_or, // Add __or__ function + .nb_or = _Py_union_type_or, // Add __or__ function }; PyTypeObject PyType_Type = { diff --git a/Objects/unionobject.c b/Objects/unionobject.c index 1b7f8ab51a4..2308bfc9f2a 100644 --- a/Objects/unionobject.c +++ b/Objects/unionobject.c @@ -237,9 +237,19 @@ dedup_and_flatten_args(PyObject* args) PyObject* i_element = PyTuple_GET_ITEM(args, i); for (Py_ssize_t j = i + 1; j < arg_length; j++) { PyObject* j_element = PyTuple_GET_ITEM(args, j); - if (i_element == j_element) { - is_duplicate = 1; + int is_ga = Py_TYPE(i_element) == &Py_GenericAliasType && + Py_TYPE(j_element) == &Py_GenericAliasType; + // RichCompare to also deduplicate GenericAlias types (slower) + is_duplicate = is_ga ? PyObject_RichCompareBool(i_element, j_element, Py_EQ) + : i_element == j_element; + // Should only happen if RichCompare fails + if (is_duplicate < 0) { + Py_DECREF(args); + Py_DECREF(new_args); + return NULL; } + if (is_duplicate) + break; } if (!is_duplicate) { Py_INCREF(i_element); @@ -290,8 +300,8 @@ is_unionable(PyObject *obj) type == &_Py_UnionType); } -static PyObject * -type_or(PyTypeObject* self, PyObject* param) +PyObject * +_Py_union_type_or(PyObject* self, PyObject* param) { PyObject *tuple = PyTuple_Pack(2, self, param); if (tuple == NULL) { @@ -404,7 +414,7 @@ static PyMethodDef union_methods[] = { {0}}; static PyNumberMethods union_as_number = { - .nb_or = (binaryfunc)type_or, // Add __or__ function + .nb_or = _Py_union_type_or, // Add __or__ function }; PyTypeObject _Py_UnionType = {