diff --git a/Lib/test/test_type_aliases.py b/Lib/test/test_type_aliases.py index ebb65d8c6cf..230bbe646ba 100644 --- a/Lib/test/test_type_aliases.py +++ b/Lib/test/test_type_aliases.py @@ -4,7 +4,9 @@ import unittest from test.support import check_syntax_error, run_code from test.typinganndata import mod_generics_cache -from typing import Callable, TypeAliasType, TypeVar, get_args +from typing import ( + Callable, TypeAliasType, TypeVar, TypeVarTuple, ParamSpec, get_args, +) class TypeParamsInvalidTest(unittest.TestCase): @@ -225,6 +227,46 @@ class TypeAliasConstructorTest(unittest.TestCase): ): TA[int] + def test_type_params_order_with_defaults(self): + HasNoDefaultT = TypeVar("HasNoDefaultT") + WithDefaultT = TypeVar("WithDefaultT", default=int) + + HasNoDefaultP = ParamSpec("HasNoDefaultP") + WithDefaultP = ParamSpec("WithDefaultP", default=HasNoDefaultP) + + HasNoDefaultTT = TypeVarTuple("HasNoDefaultTT") + WithDefaultTT = TypeVarTuple("WithDefaultTT", default=HasNoDefaultTT) + + for type_params in [ + (HasNoDefaultT, WithDefaultT), + (HasNoDefaultP, WithDefaultP), + (HasNoDefaultTT, WithDefaultTT), + ]: + with self.subTest(type_params=type_params): + TypeAliasType("A", int, type_params=type_params) # ok + + msg = "follows default type parameter" + for type_params in [ + (WithDefaultT, HasNoDefaultT), + (WithDefaultP, HasNoDefaultP), + (WithDefaultTT, HasNoDefaultTT), + (WithDefaultT, HasNoDefaultP), # different types + ]: + with self.subTest(type_params=type_params): + with self.assertRaisesRegex(TypeError, msg): + TypeAliasType("A", int, type_params=type_params) + + def test_expects_type_like(self): + T = TypeVar("T") + + msg = "Expected a type param" + with self.assertRaisesRegex(TypeError, msg): + TypeAliasType("A", int, type_params=(1,)) + with self.assertRaisesRegex(TypeError, msg): + TypeAliasType("A", int, type_params=(1, 2)) + with self.assertRaisesRegex(TypeError, msg): + TypeAliasType("A", int, type_params=(T, 2)) + def test_keywords(self): TA = TypeAliasType(name="TA", value=int) self.assertEqual(TA.__name__, "TA") diff --git a/Misc/NEWS.d/next/Library/2024-09-30-20-46-32.gh-issue-124787.3FnJnP.rst b/Misc/NEWS.d/next/Library/2024-09-30-20-46-32.gh-issue-124787.3FnJnP.rst new file mode 100644 index 00000000000..d9d1bbcf5a2 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2024-09-30-20-46-32.gh-issue-124787.3FnJnP.rst @@ -0,0 +1,4 @@ +Fix :class:`typing.TypeAliasType` with incorrect ``type_params`` argument. +Now it raises a :exc:`TypeError` when a type parameter without a default +follows one with a default, and when an entry in the ``type_params`` tuple +is not a type parameter object. diff --git a/Objects/typevarobject.c b/Objects/typevarobject.c index 51d93ed8b5b..91cc37c9a72 100644 --- a/Objects/typevarobject.c +++ b/Objects/typevarobject.c @@ -1799,6 +1799,24 @@ _Py_make_typevartuple(PyThreadState *Py_UNUSED(ignored), PyObject *v) return (PyObject *)typevartuple_alloc(v, NULL, NULL); } +static PyObject * +get_type_param_default(PyThreadState *ts, PyObject *typeparam) { + // Does not modify refcount of existing objects. + if (Py_IS_TYPE(typeparam, ts->interp->cached_objects.typevar_type)) { + return typevar_default((typevarobject *)typeparam, NULL); + } + else if (Py_IS_TYPE(typeparam, ts->interp->cached_objects.paramspec_type)) { + return paramspec_default((paramspecobject *)typeparam, NULL); + } + else if (Py_IS_TYPE(typeparam, ts->interp->cached_objects.typevartuple_type)) { + return typevartuple_default((typevartupleobject *)typeparam, NULL); + } + else { + PyErr_Format(PyExc_TypeError, "Expected a type param, got %R", typeparam); + return NULL; + } +} + static void typealias_dealloc(PyObject *self) { @@ -1906,6 +1924,65 @@ static PyGetSetDef typealias_getset[] = { {0} }; +static PyObject * +typealias_check_type_params(PyObject *type_params, int *err) { + // Can return type_params or NULL without exception set. + // Does not change the reference count of type_params, + // sets `*err` to 1 when error happens and sets an exception, + // otherwise `*err` is set to 0. + *err = 0; + if (type_params == NULL) { + return NULL; + } + + assert(PyTuple_Check(type_params)); + Py_ssize_t length = PyTuple_GET_SIZE(type_params); + if (!length) { // 0-length tuples are the same as `NULL`. + return NULL; + } + + PyThreadState *ts = _PyThreadState_GET(); + int default_seen = 0; + for (Py_ssize_t index = 0; index < length; index++) { + PyObject *type_param = PyTuple_GET_ITEM(type_params, index); + PyObject *dflt = get_type_param_default(ts, type_param); + if (dflt == NULL) { + *err = 1; + return NULL; + } + if (dflt == &_Py_NoDefaultStruct) { + if (default_seen) { + *err = 1; + PyErr_Format(PyExc_TypeError, + "non-default type parameter '%R' " + "follows default type parameter", + type_param); + return NULL; + } + } else { + default_seen = 1; + Py_DECREF(dflt); + } + } + + return type_params; +} + +static PyObject * +typelias_convert_type_params(PyObject *type_params) +{ + if ( + type_params == NULL + || Py_IsNone(type_params) + || (PyTuple_Check(type_params) && PyTuple_GET_SIZE(type_params) == 0) + ) { + return NULL; + } + else { + return type_params; + } +} + static typealiasobject * typealias_alloc(PyObject *name, PyObject *type_params, PyObject *compute_value, PyObject *value, PyObject *module) @@ -1915,16 +1992,7 @@ typealias_alloc(PyObject *name, PyObject *type_params, PyObject *compute_value, return NULL; } ta->name = Py_NewRef(name); - if ( - type_params == NULL - || Py_IsNone(type_params) - || (PyTuple_Check(type_params) && PyTuple_GET_SIZE(type_params) == 0) - ) { - ta->type_params = NULL; - } - else { - ta->type_params = Py_NewRef(type_params); - } + ta->type_params = Py_XNewRef(type_params); ta->compute_value = Py_XNewRef(compute_value); ta->value = Py_XNewRef(value); ta->module = Py_XNewRef(module); @@ -2002,11 +2070,18 @@ typealias_new_impl(PyTypeObject *type, PyObject *name, PyObject *value, PyErr_SetString(PyExc_TypeError, "type_params must be a tuple"); return NULL; } + + int err = 0; + PyObject *checked_params = typealias_check_type_params(type_params, &err); + if (err) { + return NULL; + } + PyObject *module = caller(); if (module == NULL) { return NULL; } - PyObject *ta = (PyObject *)typealias_alloc(name, type_params, NULL, value, + PyObject *ta = (PyObject *)typealias_alloc(name, checked_params, NULL, value, module); Py_DECREF(module); return ta; @@ -2072,7 +2147,7 @@ _Py_make_typealias(PyThreadState* unused, PyObject *args) assert(PyTuple_GET_SIZE(args) == 3); PyObject *name = PyTuple_GET_ITEM(args, 0); assert(PyUnicode_Check(name)); - PyObject *type_params = PyTuple_GET_ITEM(args, 1); + PyObject *type_params = typelias_convert_type_params(PyTuple_GET_ITEM(args, 1)); PyObject *compute_value = PyTuple_GET_ITEM(args, 2); assert(PyFunction_Check(compute_value)); return (PyObject *)typealias_alloc(name, type_params, compute_value, NULL, NULL);