From 5c08ce9bf712acbb3f05a3a57baf51fcb534cdf0 Mon Sep 17 00:00:00 2001 From: Mark Dickinson Date: Sun, 19 May 2019 17:51:56 +0100 Subject: [PATCH] bpo-36957: Speed up math.isqrt (#13405) * Add math.isqrt function computing the integer square root. * Code cleanup: remove redundant comments, rename some variables. * Tighten up code a bit more; use Py_XDECREF to simplify error handling. * Update Modules/mathmodule.c Co-Authored-By: Serhiy Storchaka * Update Modules/mathmodule.c Use real argument clinic type instead of an alias Co-Authored-By: Serhiy Storchaka * Add proof sketch * Updates from review. * Correct and expand documentation. * Fix bad reference handling on error; make some variables block-local; other tidying. * Style and consistency fixes. * Add missing error check; don't try to DECREF a NULL a * Simplify some error returns. * Another two test cases: - clarify that floats are rejected even if they happen to be squares of small integers - TypeError beats ValueError for a negative float * Add fast path for small inputs. Needs tests. * Speed up isqrt for n >= 2**64 as well; add extra tests. * Reduce number of test-cases to avoid dominating the run-time of test_math. * Don't perform unnecessary extra iterations when computing c_bit_length. * Abstract common uint64_t code out into a separate function. * Cleanup. * Add a missing Py_DECREF in an error branch. More cleanup. * Update Modules/mathmodule.c Add missing `static` declaration to helper function. Co-Authored-By: Serhiy Storchaka * Add missing backtick. --- Lib/test/test_math.py | 1 + Modules/mathmodule.c | 64 +++++++++++++++++++++++++++++++++++++------ 2 files changed, 57 insertions(+), 8 deletions(-) diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py index a11a3447856..853a0e62f82 100644 --- a/Lib/test/test_math.py +++ b/Lib/test/test_math.py @@ -917,6 +917,7 @@ class MathTests(unittest.TestCase): test_values = ( list(range(1000)) + list(range(10**6 - 1000, 10**6 + 1000)) + + [2**e + i for e in range(60, 200) for i in range(-40, 40)] + [3**9999, 10**5001] ) diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c index 7a0044a9fcf..a153e984ca5 100644 --- a/Modules/mathmodule.c +++ b/Modules/mathmodule.c @@ -1620,6 +1620,22 @@ completes the proof sketch. */ + +/* Approximate square root of a large 64-bit integer. + + Given `n` satisfying `2**62 <= n < 2**64`, return `a` + satisfying `(a - 1)**2 < n < (a + 1)**2`. */ + +static uint64_t +_approximate_isqrt(uint64_t n) +{ + uint32_t u = 1U + (n >> 62); + u = (u << 1) + (n >> 59) / u; + u = (u << 3) + (n >> 53) / u; + u = (u << 7) + (n >> 41) / u; + return (u << 15) + (n >> 17) / u; +} + /*[clinic input] math.isqrt @@ -1633,8 +1649,9 @@ static PyObject * math_isqrt(PyObject *module, PyObject *n) /*[clinic end generated code: output=35a6f7f980beab26 input=5b6e7ae4fa6c43d6]*/ { - int a_too_large, s; + int a_too_large, c_bit_length; size_t c, d; + uint64_t m, u; PyObject *a = NULL, *b; n = PyNumber_Index(n); @@ -1653,24 +1670,55 @@ math_isqrt(PyObject *module, PyObject *n) return PyLong_FromLong(0); } + /* c = (n.bit_length() - 1) // 2 */ c = _PyLong_NumBits(n); if (c == (size_t)(-1)) { goto error; } c = (c - 1U) / 2U; - /* s = c.bit_length() */ - s = 0; - while ((c >> s) > 0) { - ++s; + /* Fast path: if c <= 31 then n < 2**64 and we can compute directly with a + fast, almost branch-free algorithm. In the final correction, we use `u*u + - 1 >= m` instead of the simpler `u*u > m` in order to get the correct + result in the corner case where `u=2**32`. */ + if (c <= 31U) { + m = (uint64_t)PyLong_AsUnsignedLongLong(n); + Py_DECREF(n); + if (m == (uint64_t)(-1) && PyErr_Occurred()) { + return NULL; + } + u = _approximate_isqrt(m << (62U - 2U*c)) >> (31U - c); + u -= u * u - 1U >= m; + return PyLong_FromUnsignedLongLong((unsigned long long)u); } - a = PyLong_FromLong(1); + /* Slow path: n >= 2**64. We perform the first five iterations in C integer + arithmetic, then switch to using Python long integers. */ + + /* From n >= 2**64 it follows that c.bit_length() >= 6. */ + c_bit_length = 6; + while ((c >> c_bit_length) > 0U) { + ++c_bit_length; + } + + /* Initialise d and a. */ + d = c >> (c_bit_length - 5); + b = _PyLong_Rshift(n, 2U*c - 62U); + if (b == NULL) { + goto error; + } + m = (uint64_t)PyLong_AsUnsignedLongLong(b); + Py_DECREF(b); + if (m == (uint64_t)(-1) && PyErr_Occurred()) { + goto error; + } + u = _approximate_isqrt(m) >> (31U - d); + a = PyLong_FromUnsignedLongLong((unsigned long long)u); if (a == NULL) { goto error; } - d = 0; - while (--s >= 0) { + + for (int s = c_bit_length - 6; s >= 0; --s) { PyObject *q; size_t e = d;