From 4c8a9a2df3c31b1c29d0b3cf74523e3c8b3dae72 Mon Sep 17 00:00:00 2001 From: Mark Dickinson Date: Sat, 15 May 2010 17:02:38 +0000 Subject: [PATCH] Issue #8692: Improve performance of math.factorial: (1) use a different algorithm that roughly halves the total number of multiplications required and results in more balanced multiplications (2) use a lookup table for small arguments (3) fast accumulation of products in C integer arithmetic rather than PyLong arithmetic when possible. Typical speedup, from unscientific testing on a 64-bit laptop, is 4.5x to 6.5x for arguments in the range 100 - 10000. Patch by Daniel Stutzbach; extensive reviews by Alexander Belopolsky. --- Lib/test/test_math.py | 71 ++++++++++-- Misc/NEWS | 6 + Modules/mathmodule.c | 260 ++++++++++++++++++++++++++++++++++++++---- 3 files changed, 307 insertions(+), 30 deletions(-) diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py index 48d9b1abbdc..6c4443540f3 100644 --- a/Lib/test/test_math.py +++ b/Lib/test/test_math.py @@ -60,6 +60,56 @@ def ulps_check(expected, got, ulps=20): return "error = {} ulps; permitted error = {} ulps".format(ulps_error, ulps) +# Here's a pure Python version of the math.factorial algorithm, for +# documentation and comparison purposes. +# +# Formula: +# +# factorial(n) = factorial_odd_part(n) << (n - count_set_bits(n)) +# +# where +# +# factorial_odd_part(n) = product_{i >= 0} product_{0 < j <= n >> i; j odd} j +# +# The outer product above is an infinite product, but once i >= n.bit_length, +# (n >> i) < 1 and the corresponding term of the product is empty. So only the +# finitely many terms for 0 <= i < n.bit_length() contribute anything. +# +# We iterate downwards from i == n.bit_length() - 1 to i == 0. The inner +# product in the formula above starts at 1 for i == n.bit_length(); for each i +# < n.bit_length() we get the inner product for i from that for i + 1 by +# multiplying by all j in {n >> i+1 < j <= n >> i; j odd}. 
In Python terms, +# this set is range((n >> i+1) + 1 | 1, (n >> i) + 1 | 1, 2). + +def count_set_bits(n): + """Number of '1' bits in binary expansion of a nonnegative integer.""" + return 1 + count_set_bits(n & n - 1) if n else 0 + +def partial_product(start, stop): + """Product of integers in range(start, stop, 2), computed recursively. + start and stop should both be odd, with start <= stop. + + """ + numfactors = (stop - start) >> 1 + if not numfactors: + return 1 + elif numfactors == 1: + return start + else: + mid = (start + numfactors) | 1 + return partial_product(start, mid) * partial_product(mid, stop) + +def py_factorial(n): + """Factorial of nonnegative integer n, via "Binary Split Factorial Formula" + described at http://www.luschny.de/math/factorial/binarysplitfact.html + + """ + inner = outer = 1 + for i in reversed(range(n.bit_length())): + inner *= partial_product((n >> i + 1) + 1 | 1, (n >> i) + 1 | 1) + outer *= inner + return outer << (n - count_set_bits(n)) + def acc_check(expected, got, rel_err=2e-15, abs_err = 5e-323): """Determine whether non-NaN floats a and b are equal to within a (small) rounding error.
The default values for rel_err and @@ -365,18 +415,19 @@ class MathTests(unittest.TestCase): self.ftest('fabs(1)', math.fabs(1), 1) def testFactorial(self): - def fact(n): - result = 1 - for i in range(1, int(n)+1): - result *= i - return result - values = list(range(10)) + [50, 100, 500] - random.shuffle(values) - for x in values: - for cast in (int, float): - self.assertEqual(math.factorial(cast(x)), fact(x), (x, fact(x), math.factorial(x))) + self.assertEqual(math.factorial(0), 1) + self.assertEqual(math.factorial(0.0), 1) + total = 1 + for i in range(1, 1000): + total *= i + self.assertEqual(math.factorial(i), total) + self.assertEqual(math.factorial(float(i)), total) + self.assertEqual(math.factorial(i), py_factorial(i)) self.assertRaises(ValueError, math.factorial, -1) + self.assertRaises(ValueError, math.factorial, -1.0) self.assertRaises(ValueError, math.factorial, math.pi) + self.assertRaises(OverflowError, math.factorial, sys.maxsize+1) + self.assertRaises(OverflowError, math.factorial, 10e100) def testFloor(self): self.assertRaises(TypeError, math.floor) diff --git a/Misc/NEWS b/Misc/NEWS index 3da54ab4f73..a4a69388cf0 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -1132,6 +1132,12 @@ Library Extension Modules ----------------- +- Issue #8692: Optimize math.factorial: replace the previous naive + algorithm with an improved 'binary-split' algorithm that uses fewer + multiplications and allows many of the multiplications to be + performed using plain C integer arithmetic instead of PyLong + arithmetic. Also uses a lookup table for small arguments. + - Issue #8674: Fixed a number of incorrect or undefined-behaviour-inducing overflow checks in the audioop module. 
diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c index 76d7906a19c..d57ad90b107 100644 --- a/Modules/mathmodule.c +++ b/Modules/mathmodule.c @@ -1129,18 +1129,239 @@ PyDoc_STRVAR(math_fsum_doc, Return an accurate floating point sum of values in the iterable.\n\ Assumes IEEE-754 floating point arithmetic."); +/* Return the smallest integer k such that n < 2**k, or 0 if n == 0. + * Equivalent to floor(lg(n))+1. Also equivalent to: bitwidth_of_type - + * count_leading_zero_bits(n) + */ + +/* XXX: This routine does more or less the same thing as + * bits_in_digit() in Objects/longobject.c. Someday it would be nice to + * consolidate them. On BSD, there's a library function called fls() + * that we could use, and GCC provides __builtin_clz(). + */ + +static unsigned long +bit_length(unsigned long n) +{ + unsigned long len = 0; + while (n != 0) { + ++len; + n >>= 1; + } + return len; +} + +static unsigned long +count_set_bits(unsigned long n) +{ + unsigned long count = 0; + while (n != 0) { + ++count; + n &= n - 1; /* clear least significant set bit */ + } + return count; +} + +/* Divide-and-conquer factorial algorithm + * + * Based on the formula and pseudo-code provided at: + * http://www.luschny.de/math/factorial/binarysplitfact.html + * + * Faster algorithms exist, but they're more complicated and depend on + * a fast prime factorization algorithm. + * + * Notes on the algorithm + * ---------------------- + * + * factorial(n) is written in the form 2**k * m, with m odd. k and m are + * computed separately, and then combined using a left shift.
+ * + * The function factorial_odd_part computes the odd part m (i.e., the greatest + * odd divisor) of factorial(n), using the formula: + * + * factorial_odd_part(n) = + * + * product_{i >= 0} product_{0 < j <= n / 2**i, j odd} j + * + * Example: factorial_odd_part(20) = + * + * (1) * + * (1) * + * (1 * 3 * 5) * + * (1 * 3 * 5 * 7 * 9) * + * (1 * 3 * 5 * 7 * 9 * 11 * 13 * 15 * 17 * 19) + * + * Here i goes from large to small: the first term corresponds to i=4 (any + * larger i gives an empty product), and the last term corresponds to i=0. + * Each term can be computed from the last by multiplying by the extra odd + * numbers required: e.g., to get from the penultimate term to the last one, + * we multiply by (11 * 13 * 15 * 17 * 19). + * + * To see a hint of why this formula works, here are the same numbers as above + * but with the even parts (i.e., the appropriate powers of 2) included. For + * each subterm in the product for i, we multiply that subterm by 2**i: + * + * factorial(20) = + * + * (16) * + * (8) * + * (4 * 12 * 20) * + * (2 * 6 * 10 * 14 * 18) * + * (1 * 3 * 5 * 7 * 9 * 11 * 13 * 15 * 17 * 19) + * + * The factorial_partial_product function computes the product of all odd j in + * range(start, stop) for given start and stop. It's used to compute the + * partial products like (11 * 13 * 15 * 17 * 19) in the example above. It + * operates recursively, repeatedly splitting the range into two roughly equal + * pieces until the subranges are small enough to be computed using only C + * integer arithmetic. + * + * The two-valuation k (i.e., the exponent of the largest power of 2 dividing + * the factorial) is computed independently in the main math_factorial + * function. By standard results, its value is: + * + * two_valuation = n//2 + n//4 + n//8 + .... + * + * It can be shown (e.g., by complete induction on n) that two_valuation is + * equal to n - count_set_bits(n), where count_set_bits(n) gives the number of + * '1'-bits in the binary expansion of n.
+ */ + +/* factorial_partial_product: Compute product(range(start, stop, 2)) using + * divide and conquer. Assumes start and stop are odd and stop > start. + * max_bits must be >= bit_length(stop - 2). */ + +static PyObject * +factorial_partial_product(unsigned long start, unsigned long stop, + unsigned long max_bits) +{ + unsigned long midpoint, num_operands; + PyObject *left = NULL, *right = NULL, *result = NULL; + + /* If the return value will fit an unsigned long, then we can + * multiply in a tight, fast loop where each multiply is O(1). + * Compute an upper bound on the number of bits required to store + * the answer. + * + * Storing some integer z requires floor(lg(z))+1 bits, which is + * conveniently the value returned by bit_length(z). The + * product x*y will require at most + * bit_length(x) + bit_length(y) bits to store, based + * on the idea that lg product = lg x + lg y. + * + * We know that stop - 2 is the largest number to be multiplied. From + * there, we have: bit_length(answer) <= num_operands * + * bit_length(stop - 2) + */ + + num_operands = (stop - start) / 2; + /* The "num_operands <= 8 * SIZEOF_LONG" check guards against the + * unlikely case of an overflow in num_operands * max_bits. */ + if (num_operands <= 8 * SIZEOF_LONG && + num_operands * max_bits <= 8 * SIZEOF_LONG) { + unsigned long j, total; + for (total = start, j = start + 2; j < stop; j += 2) + total *= j; + return PyLong_FromUnsignedLong(total); + } + + /* find midpoint of range(start, stop), rounded up to next odd number. */ + midpoint = (start + num_operands) | 1; + left = factorial_partial_product(start, midpoint, + bit_length(midpoint - 2)); + if (left == NULL) + goto error; + right = factorial_partial_product(midpoint, stop, max_bits); + if (right == NULL) + goto error; + result = PyNumber_Multiply(left, right); + + error: + Py_XDECREF(left); + Py_XDECREF(right); + return result; +} + +/* factorial_odd_part: compute the odd part of factorial(n). 
*/ + +static PyObject * +factorial_odd_part(unsigned long n) +{ + long i; + unsigned long v, lower, upper; + PyObject *partial, *tmp, *inner, *outer; + + inner = PyLong_FromLong(1); + if (inner == NULL) + return NULL; + outer = inner; + Py_INCREF(outer); + + upper = 3; + for (i = bit_length(n) - 2; i >= 0; i--) { + v = n >> i; + if (v <= 2) + continue; + lower = upper; + /* (v + 1) | 1 = least odd integer strictly larger than n / 2**i */ + upper = (v + 1) | 1; + /* Here inner is the product of all odd integers j in the range (0, + n/2**(i+1)]. The factorial_partial_product call below gives the + product of all odd integers j in the range (n/2**(i+1), n/2**i]. */ + partial = factorial_partial_product(lower, upper, bit_length(upper-2)); + /* inner *= partial */ + if (partial == NULL) + goto error; + tmp = PyNumber_Multiply(inner, partial); + Py_DECREF(partial); + if (tmp == NULL) + goto error; + Py_DECREF(inner); + inner = tmp; + /* Now inner is the product of all odd integers j in the range (0, + n/2**i], giving the inner product in the formula above.
*/ + + /* outer *= inner; */ + tmp = PyNumber_Multiply(outer, inner); + if (tmp == NULL) + goto error; + Py_DECREF(outer); + outer = tmp; + } + + Py_DECREF(inner); + return outer; + + error: /* failure: drop both references and return NULL */ + Py_DECREF(outer); + Py_DECREF(inner); + return NULL; +} + +/* Lookup table for small factorial values */ + +static const unsigned long SmallFactorials[] = { + 1, 1, 2, 6, 24, 120, 720, 5040, 40320, + 362880, 3628800, 39916800, 479001600, +#if SIZEOF_LONG >= 8 + 6227020800, 87178291200, 1307674368000, + 20922789888000, 355687428096000, 6402373705728000, + 121645100408832000, 2432902008176640000 +#endif +}; + static PyObject * math_factorial(PyObject *self, PyObject *arg) { - long i, x; - PyObject *result, *iobj, *newresult; + long x; + PyObject *result, *odd_part, *two_valuation; if (PyFloat_Check(arg)) { PyObject *lx; double dx = PyFloat_AS_DOUBLE((PyFloatObject *)arg); if (!(Py_IS_FINITE(dx) && dx == floor(dx))) { PyErr_SetString(PyExc_ValueError, - "factorial() only accepts integral values"); + "factorial() only accepts integral values"); return NULL; } lx = PyLong_FromDouble(dx); @@ -1156,29 +1377,28 @@ math_factorial(PyObject *self, PyObject *arg) return NULL; if (x < 0) { PyErr_SetString(PyExc_ValueError, - "factorial() not defined for negative values"); + "factorial() not defined for negative values"); return NULL; } - result = (PyObject *)PyLong_FromLong(1); - if (result == NULL) + /* use lookup table if x is small */ + if (x < (long)(sizeof(SmallFactorials)/sizeof(SmallFactorials[0]))) + return PyLong_FromUnsignedLong(SmallFactorials[x]); + + /* else express in the form odd_part * 2**two_valuation, and compute as + odd_part << two_valuation.
*/ + odd_part = factorial_odd_part(x); + if (odd_part == NULL) + return NULL; + two_valuation = PyLong_FromLong(x - count_set_bits(x)); + if (two_valuation == NULL) { + Py_DECREF(odd_part); return NULL; - for (i=1 ; i<=x ; i++) { - iobj = (PyObject *)PyLong_FromLong(i); - if (iobj == NULL) - goto error; - newresult = PyNumber_Multiply(result, iobj); - Py_DECREF(iobj); - if (newresult == NULL) - goto error; - Py_DECREF(result); - result = newresult; } + result = PyNumber_Lshift(odd_part, two_valuation); + Py_DECREF(two_valuation); + Py_DECREF(odd_part); return result; - -error: - Py_DECREF(result); - return NULL; } PyDoc_STRVAR(math_factorial_doc,