From d02c5e9b55a8651b7d396ac3f2bdedf1fc1780b5 Mon Sep 17 00:00:00 2001 From: Mark Dickinson Date: Sat, 15 Jan 2022 09:58:04 +0000 Subject: [PATCH] bpo-46258: Streamline isqrt fast path (#30333) --- .../2022-01-04-18-05-25.bpo-46258.DYgwRo.rst | 2 + Modules/mathmodule.c | 57 ++++++++++++++----- 2 files changed, 45 insertions(+), 14 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2022-01-04-18-05-25.bpo-46258.DYgwRo.rst diff --git a/Misc/NEWS.d/next/Library/2022-01-04-18-05-25.bpo-46258.DYgwRo.rst b/Misc/NEWS.d/next/Library/2022-01-04-18-05-25.bpo-46258.DYgwRo.rst new file mode 100644 index 00000000000..b918ed1a5d9 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2022-01-04-18-05-25.bpo-46258.DYgwRo.rst @@ -0,0 +1,2 @@ +Speed up :func:`math.isqrt` for small positive integers by replacing two +division steps with a lookup table. diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c index 3ab1a077604..0c7d4de0686 100644 --- a/Modules/mathmodule.c +++ b/Modules/mathmodule.c @@ -1718,20 +1718,49 @@ completes the proof sketch. */ +/* + The _approximate_isqrt_tab table provides approximate square roots for + 16-bit integers. For any n in the range 2**14 <= n < 2**16, the value + + a = _approximate_isqrt_tab[(n >> 8) - 64] + + is an approximate square root of n, satisfying (a - 1)**2 < n < (a + 1)**2. + + The table was computed in Python using the expression: + + [min(round(sqrt(256*n + 128)), 255) for n in range(64, 256)] +*/ + +static const uint8_t _approximate_isqrt_tab[192] = { + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, + 140, 141, 142, 143, 144, 144, 145, 146, 147, 148, 149, 150, + 151, 151, 152, 153, 154, 155, 156, 156, 157, 158, 159, 160, + 160, 161, 162, 163, 164, 164, 165, 166, 167, 167, 168, 169, + 170, 170, 171, 172, 173, 173, 174, 175, 176, 176, 177, 178, + 179, 179, 180, 181, 181, 182, 183, 183, 184, 185, 186, 186, + 187, 188, 188, 189, 190, 190, 191, 192, 192, 193, 194, 194, + 195, 196, 196, 197, 198, 198, 199, 200, 200, 201, 201, 202, + 203, 203, 204, 205, 205, 206, 206, 207, 208, 208, 209, 210, + 210, 211, 211, 212, 213, 213, 214, 214, 215, 216, 216, 217, + 217, 218, 219, 219, 220, 220, 221, 221, 222, 223, 223, 224, + 224, 225, 225, 226, 227, 227, 228, 228, 229, 229, 230, 230, + 231, 232, 232, 233, 233, 234, 234, 235, 235, 236, 237, 237, + 238, 238, 239, 239, 240, 240, 241, 241, 242, 242, 243, 243, + 244, 244, 245, 246, 246, 247, 247, 248, 248, 249, 249, 250, + 250, 251, 251, 252, 252, 253, 253, 254, 254, 255, 255, 255, +}; /* Approximate square root of a large 64-bit integer. Given `n` satisfying `2**62 <= n < 2**64`, return `a` satisfying `(a - 1)**2 < n < (a + 1)**2`. */ -static uint64_t +static inline uint32_t _approximate_isqrt(uint64_t n) { - uint32_t u = 1U + (n >> 62); - u = (u << 1) + (n >> 59) / u; - u = (u << 3) + (n >> 53) / u; - u = (u << 7) + (n >> 41) / u; - return (u << 15) + (n >> 17) / u; + uint32_t u = _approximate_isqrt_tab[(n >> 56) - 64]; + u = (u << 7) + (uint32_t)(n >> 41) / u; + return (u << 15) + (uint32_t)((n >> 17) / u); } /*[clinic input] @@ -1749,7 +1778,8 @@ math_isqrt(PyObject *module, PyObject *n) { int a_too_large, c_bit_length; size_t c, d; - uint64_t m, u; + uint64_t m; + uint32_t u; PyObject *a = NULL, *b; n = _PyNumber_Index(n); @@ -1776,18 +1806,17 @@ math_isqrt(PyObject *module, PyObject *n) c = (c - 1U) / 2U; /* Fast path: if c <= 31 then n < 2**64 and we can compute directly with a - fast, almost branch-free algorithm. In the final correction, we use `u*u - - 1 >= m` instead of the simpler `u*u > m` in order to get the correct - result in the corner case where `u=2**32`. */ + fast, almost branch-free algorithm. */ if (c <= 31U) { + int shift = 31 - (int)c; m = (uint64_t)PyLong_AsUnsignedLongLong(n); Py_DECREF(n); if (m == (uint64_t)(-1) && PyErr_Occurred()) { return NULL; } - u = _approximate_isqrt(m << (62U - 2U*c)) >> (31U - c); - u -= u * u - 1U >= m; - return PyLong_FromUnsignedLongLong((unsigned long long)u); + u = _approximate_isqrt(m << 2*shift) >> shift; + u -= (uint64_t)u * u > m; + return PyLong_FromUnsignedLong(u); } /* Slow path: n >= 2**64. We perform the first five iterations in C integer @@ -1811,7 +1840,7 @@ math_isqrt(PyObject *module, PyObject *n) goto error; } u = _approximate_isqrt(m) >> (31U - d); - a = PyLong_FromUnsignedLongLong((unsigned long long)u); + a = PyLong_FromUnsignedLong(u); if (a == NULL) { goto error; }