mirror of https://github.com/python/cpython
bpo-35431: Refactor math.comb() implementation. (GH-13725)
* Fixed some bugs. * Added support for index-likes objects. * Improved error messages. * Cleaned up and optimized the code. * Added more tests.
This commit is contained in:
parent
9843bc110d
commit
2b843ac0ae
|
@ -238,11 +238,11 @@ Number-theoretic and representation functions
|
|||
and without order.
|
||||
|
||||
Also called the binomial coefficient. It is mathematically equal to the expression
|
||||
``n! / (k! (n - k)!)``. It is equivalent to the coefficient of k-th term in
|
||||
``n! / (k! (n - k)!)``. It is equivalent to the coefficient of the *k*-th term in the
|
||||
polynomial expansion of the expression ``(1 + x) ** n``.
|
||||
|
||||
Raises :exc:`TypeError` if the arguments not integers.
|
||||
Raises :exc:`ValueError` if the arguments are negative or if k > n.
|
||||
Raises :exc:`ValueError` if the arguments are negative or if *k* > *n*.
|
||||
|
||||
.. versionadded:: 3.8
|
||||
|
||||
|
|
|
@ -1893,9 +1893,11 @@ class IsCloseTests(unittest.TestCase):
|
|||
# Raises TypeError if any argument is non-integer or argument count is
|
||||
# not 2
|
||||
self.assertRaises(TypeError, comb, 10, 1.0)
|
||||
self.assertRaises(TypeError, comb, 10, decimal.Decimal(1.0))
|
||||
self.assertRaises(TypeError, comb, 10, "1")
|
||||
self.assertRaises(TypeError, comb, "10", 1)
|
||||
self.assertRaises(TypeError, comb, 10.0, 1)
|
||||
self.assertRaises(TypeError, comb, decimal.Decimal(10.0), 1)
|
||||
self.assertRaises(TypeError, comb, "10", 1)
|
||||
|
||||
self.assertRaises(TypeError, comb, 10)
|
||||
self.assertRaises(TypeError, comb, 10, 1, 3)
|
||||
|
@ -1903,15 +1905,28 @@ class IsCloseTests(unittest.TestCase):
|
|||
|
||||
# Raises Value error if not k or n are negative numbers
|
||||
self.assertRaises(ValueError, comb, -1, 1)
|
||||
self.assertRaises(ValueError, comb, -10*10, 1)
|
||||
self.assertRaises(ValueError, comb, -2**1000, 1)
|
||||
self.assertRaises(ValueError, comb, 1, -1)
|
||||
self.assertRaises(ValueError, comb, 1, -10*10)
|
||||
self.assertRaises(ValueError, comb, 1, -2**1000)
|
||||
|
||||
# Raises value error if k is greater than n
|
||||
self.assertRaises(ValueError, comb, 1, 10**10)
|
||||
self.assertRaises(ValueError, comb, 0, 1)
|
||||
self.assertRaises(ValueError, comb, 1, 2)
|
||||
self.assertRaises(ValueError, comb, 1, 2**1000)
|
||||
|
||||
n = 2**1000
|
||||
self.assertEqual(comb(n, 0), 1)
|
||||
self.assertEqual(comb(n, 1), n)
|
||||
self.assertEqual(comb(n, 2), n * (n-1) // 2)
|
||||
self.assertEqual(comb(n, n), 1)
|
||||
self.assertEqual(comb(n, n-1), n)
|
||||
self.assertEqual(comb(n, n-2), n * (n-1) // 2)
|
||||
self.assertRaises((OverflowError, MemoryError), comb, n, n//2)
|
||||
|
||||
for n, k in (True, True), (True, False), (False, False):
|
||||
self.assertEqual(comb(n, k), 1)
|
||||
self.assertIs(type(comb(n, k)), int)
|
||||
self.assertEqual(comb(MyIndexable(5), MyIndexable(2)), 10)
|
||||
self.assertIs(type(comb(MyIndexable(5), MyIndexable(2))), int)
|
||||
|
||||
|
||||
def test_main():
|
||||
|
|
|
@ -639,10 +639,10 @@ exit:
|
|||
}
|
||||
|
||||
PyDoc_STRVAR(math_comb__doc__,
|
||||
"comb($module, /, n, k)\n"
|
||||
"comb($module, n, k, /)\n"
|
||||
"--\n"
|
||||
"\n"
|
||||
"Number of ways to choose *k* items from *n* items without repetition and without order.\n"
|
||||
"Number of ways to choose k items from n items without repetition and without order.\n"
|
||||
"\n"
|
||||
"Also called the binomial coefficient. It is mathematically equal to the expression\n"
|
||||
"n! / (k! * (n - k)!). It is equivalent to the coefficient of k-th term in\n"
|
||||
|
@ -652,38 +652,26 @@ PyDoc_STRVAR(math_comb__doc__,
|
|||
"Raises ValueError if the arguments are negative or if k > n.");
|
||||
|
||||
#define MATH_COMB_METHODDEF \
|
||||
{"comb", (PyCFunction)(void(*)(void))math_comb, METH_FASTCALL|METH_KEYWORDS, math_comb__doc__},
|
||||
{"comb", (PyCFunction)(void(*)(void))math_comb, METH_FASTCALL, math_comb__doc__},
|
||||
|
||||
static PyObject *
|
||||
math_comb_impl(PyObject *module, PyObject *n, PyObject *k);
|
||||
|
||||
static PyObject *
|
||||
math_comb(PyObject *module, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
|
||||
math_comb(PyObject *module, PyObject *const *args, Py_ssize_t nargs)
|
||||
{
|
||||
PyObject *return_value = NULL;
|
||||
static const char * const _keywords[] = {"n", "k", NULL};
|
||||
static _PyArg_Parser _parser = {NULL, _keywords, "comb", 0};
|
||||
PyObject *argsbuf[2];
|
||||
PyObject *n;
|
||||
PyObject *k;
|
||||
|
||||
args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, 2, 2, 0, argsbuf);
|
||||
if (!args) {
|
||||
goto exit;
|
||||
}
|
||||
if (!PyLong_Check(args[0])) {
|
||||
_PyArg_BadArgument("comb", 1, "int", args[0]);
|
||||
if (!_PyArg_CheckPositional("comb", nargs, 2, 2)) {
|
||||
goto exit;
|
||||
}
|
||||
n = args[0];
|
||||
if (!PyLong_Check(args[1])) {
|
||||
_PyArg_BadArgument("comb", 2, "int", args[1]);
|
||||
goto exit;
|
||||
}
|
||||
k = args[1];
|
||||
return_value = math_comb_impl(module, n, k);
|
||||
|
||||
exit:
|
||||
return return_value;
|
||||
}
|
||||
/*[clinic end generated code: output=00aa76356759617a input=a9049054013a1b77]*/
|
||||
/*[clinic end generated code: output=6709521e5e1d90ec input=a9049054013a1b77]*/
|
||||
|
|
|
@ -3001,10 +3001,11 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start)
|
|||
/*[clinic input]
|
||||
math.comb
|
||||
|
||||
n: object(subclass_of='&PyLong_Type')
|
||||
k: object(subclass_of='&PyLong_Type')
|
||||
n: object
|
||||
k: object
|
||||
/
|
||||
|
||||
Number of ways to choose *k* items from *n* items without repetition and without order.
|
||||
Number of ways to choose k items from n items without repetition and without order.
|
||||
|
||||
Also called the binomial coefficient. It is mathematically equal to the expression
|
||||
n! / (k! * (n - k)!). It is equivalent to the coefficient of k-th term in
|
||||
|
@ -3017,103 +3018,109 @@ Raises ValueError if the arguments are negative or if k > n.
|
|||
|
||||
static PyObject *
|
||||
math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
|
||||
/*[clinic end generated code: output=bd2cec8d854f3493 input=565f340f98efb5b5]*/
|
||||
/*[clinic end generated code: output=bd2cec8d854f3493 input=2f336ac9ec8242f9]*/
|
||||
{
|
||||
PyObject *val = NULL,
|
||||
*temp_obj1 = NULL,
|
||||
*temp_obj2 = NULL,
|
||||
*dump_var = NULL;
|
||||
PyObject *result = NULL, *factor = NULL, *temp;
|
||||
int overflow, cmp;
|
||||
long long i, terms;
|
||||
long long i, factors;
|
||||
|
||||
cmp = PyObject_RichCompareBool(n, k, Py_LT);
|
||||
if (cmp < 0) {
|
||||
goto fail_comb;
|
||||
n = PyNumber_Index(n);
|
||||
if (n == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
else if (cmp > 0) {
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"n must be an integer greater than or equal to k");
|
||||
goto fail_comb;
|
||||
k = PyNumber_Index(k);
|
||||
if (k == NULL) {
|
||||
Py_DECREF(n);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
/* b = min(b, a - b) */
|
||||
dump_var = PyNumber_Subtract(n, k);
|
||||
if (dump_var == NULL) {
|
||||
goto fail_comb;
|
||||
if (Py_SIZE(n) < 0) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"n must be a non-negative integer");
|
||||
goto error;
|
||||
}
|
||||
cmp = PyObject_RichCompareBool(k, dump_var, Py_GT);
|
||||
if (cmp < 0) {
|
||||
goto fail_comb;
|
||||
/* k = min(k, n - k) */
|
||||
temp = PyNumber_Subtract(n, k);
|
||||
if (temp == NULL) {
|
||||
goto error;
|
||||
}
|
||||
else if (cmp > 0) {
|
||||
k = dump_var;
|
||||
dump_var = NULL;
|
||||
if (Py_SIZE(temp) < 0) {
|
||||
Py_DECREF(temp);
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"k must be an integer less than or equal to n");
|
||||
goto error;
|
||||
}
|
||||
cmp = PyObject_RichCompareBool(k, temp, Py_GT);
|
||||
if (cmp > 0) {
|
||||
Py_SETREF(k, temp);
|
||||
}
|
||||
else {
|
||||
Py_DECREF(dump_var);
|
||||
dump_var = NULL;
|
||||
Py_DECREF(temp);
|
||||
if (cmp < 0) {
|
||||
goto error;
|
||||
}
|
||||
}
|
||||
|
||||
terms = PyLong_AsLongLongAndOverflow(k, &overflow);
|
||||
if (terms < 0 && PyErr_Occurred()) {
|
||||
goto fail_comb;
|
||||
}
|
||||
else if (overflow > 0) {
|
||||
factors = PyLong_AsLongLongAndOverflow(k, &overflow);
|
||||
if (overflow > 0) {
|
||||
PyErr_Format(PyExc_OverflowError,
|
||||
"minimum(n - k, k) must not exceed %lld",
|
||||
"min(n - k, k) must not exceed %lld",
|
||||
LLONG_MAX);
|
||||
goto fail_comb;
|
||||
goto error;
|
||||
}
|
||||
else if (overflow < 0 || terms < 0) {
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"k must be a positive integer");
|
||||
goto fail_comb;
|
||||
else if (overflow < 0 || factors < 0) {
|
||||
if (!PyErr_Occurred()) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"k must be a non-negative integer");
|
||||
}
|
||||
goto error;
|
||||
}
|
||||
|
||||
if (terms == 0) {
|
||||
return PyNumber_Long(_PyLong_One);
|
||||
if (factors == 0) {
|
||||
result = PyLong_FromLong(1);
|
||||
goto done;
|
||||
}
|
||||
|
||||
val = PyNumber_Long(n);
|
||||
for (i = 1; i < terms; ++i) {
|
||||
temp_obj1 = PyLong_FromSsize_t(i);
|
||||
if (temp_obj1 == NULL) {
|
||||
goto fail_comb;
|
||||
}
|
||||
temp_obj2 = PyNumber_Subtract(n, temp_obj1);
|
||||
if (temp_obj2 == NULL) {
|
||||
goto fail_comb;
|
||||
}
|
||||
dump_var = val;
|
||||
val = PyNumber_Multiply(val, temp_obj2);
|
||||
if (val == NULL) {
|
||||
goto fail_comb;
|
||||
}
|
||||
Py_DECREF(dump_var);
|
||||
dump_var = NULL;
|
||||
Py_DECREF(temp_obj2);
|
||||
temp_obj2 = PyLong_FromUnsignedLongLong((unsigned long long)(i + 1));
|
||||
if (temp_obj2 == NULL) {
|
||||
goto fail_comb;
|
||||
}
|
||||
dump_var = val;
|
||||
val = PyNumber_FloorDivide(val, temp_obj2);
|
||||
if (val == NULL) {
|
||||
goto fail_comb;
|
||||
}
|
||||
Py_DECREF(dump_var);
|
||||
Py_DECREF(temp_obj1);
|
||||
Py_DECREF(temp_obj2);
|
||||
result = n;
|
||||
Py_INCREF(result);
|
||||
if (factors == 1) {
|
||||
goto done;
|
||||
}
|
||||
|
||||
return val;
|
||||
factor = n;
|
||||
Py_INCREF(factor);
|
||||
for (i = 1; i < factors; ++i) {
|
||||
Py_SETREF(factor, PyNumber_Subtract(factor, _PyLong_One));
|
||||
if (factor == NULL) {
|
||||
goto error;
|
||||
}
|
||||
Py_SETREF(result, PyNumber_Multiply(result, factor));
|
||||
if (result == NULL) {
|
||||
goto error;
|
||||
}
|
||||
|
||||
fail_comb:
|
||||
Py_XDECREF(val);
|
||||
Py_XDECREF(dump_var);
|
||||
Py_XDECREF(temp_obj1);
|
||||
Py_XDECREF(temp_obj2);
|
||||
temp = PyLong_FromUnsignedLongLong((unsigned long long)i + 1);
|
||||
if (temp == NULL) {
|
||||
goto error;
|
||||
}
|
||||
Py_SETREF(result, PyNumber_FloorDivide(result, temp));
|
||||
Py_DECREF(temp);
|
||||
if (result == NULL) {
|
||||
goto error;
|
||||
}
|
||||
}
|
||||
Py_DECREF(factor);
|
||||
|
||||
done:
|
||||
Py_DECREF(n);
|
||||
Py_DECREF(k);
|
||||
return result;
|
||||
|
||||
error:
|
||||
Py_XDECREF(factor);
|
||||
Py_XDECREF(result);
|
||||
Py_DECREF(n);
|
||||
Py_DECREF(k);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue