diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c index ed114767530..76d821c65b4 100644 --- a/Modules/mathmodule.c +++ b/Modules/mathmodule.c @@ -3056,6 +3056,12 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k) "n must be a non-negative integer"); goto error; } + if (Py_SIZE(k) < 0) { + PyErr_SetString(PyExc_ValueError, + "k must be a non-negative integer"); + goto error; + } + cmp = PyObject_RichCompareBool(n, k, Py_LT); if (cmp != 0) { if (cmp > 0) { @@ -3072,11 +3078,8 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k) LLONG_MAX); goto error; } - else if (overflow < 0 || factors < 0) { - if (!PyErr_Occurred()) { - PyErr_SetString(PyExc_ValueError, - "k must be a non-negative integer"); - } + else if (factors == -1) { + /* k is nonnegative, so a return value of -1 can only indicate error */ goto error; } @@ -3176,6 +3179,12 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k) "n must be a non-negative integer"); goto error; } + if (Py_SIZE(k) < 0) { + PyErr_SetString(PyExc_ValueError, + "k must be a non-negative integer"); + goto error; + } + /* k = min(k, n - k) */ temp = PyNumber_Subtract(n, k); if (temp == NULL) { @@ -3204,11 +3213,8 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k) LLONG_MAX); goto error; } - else if (overflow < 0 || factors < 0) { - if (!PyErr_Occurred()) { - PyErr_SetString(PyExc_ValueError, - "k must be a non-negative integer"); - } + if (factors == -1) { + /* k is nonnegative, so a return value of -1 can only indicate error */ goto error; }