From 73934b9da07daefb203e7d26089e7486a1ce4fdf Mon Sep 17 00:00:00 2001 From: Mark Dickinson Date: Sat, 18 May 2019 12:29:50 +0100 Subject: [PATCH] bpo-36887: add math.isqrt (GH-13244) * Add math.isqrt function computing the integer square root. * Code cleanup: remove redundant comments, rename some variables. * Tighten up code a bit more; use Py_XDECREF to simplify error handling. * Update Modules/mathmodule.c Co-Authored-By: Serhiy Storchaka * Update Modules/mathmodule.c Use real argument clinic type instead of an alias Co-Authored-By: Serhiy Storchaka * Add proof sketch * Updates from review. * Correct and expand documentation. * Fix bad reference handling on error; make some variables block-local; other tidying. * Style and consistency fixes. * Add missing error check; don't try to DECREF a NULL a * Simplify some error returns. * Another two test cases: - clarify that floats are rejected even if they happen to be squares of small integers - TypeError beats ValueError for a negative float * Documentation and markup improvements; thanks Serhiy for the suggestions! * Cleaner Misc/NEWS entry wording. * Clean up (with one fix) to the algorithm explanation and proof. --- Doc/library/math.rst | 17 ++ Doc/whatsnew/3.8.rst | 3 + Lib/test/test_math.py | 51 ++++ .../2019-05-11-14-50-59.bpo-36887.XD3f22.rst | 1 + Modules/clinic/mathmodule.c.h | 11 +- Modules/mathmodule.c | 261 ++++++++++++++++++ 6 files changed, 343 insertions(+), 1 deletion(-) create mode 100644 Misc/NEWS.d/next/Library/2019-05-11-14-50-59.bpo-36887.XD3f22.rst diff --git a/Doc/library/math.rst b/Doc/library/math.rst index 49f932d0384..bf660ae9def 100644 --- a/Doc/library/math.rst +++ b/Doc/library/math.rst @@ -166,6 +166,20 @@ Number-theoretic and representation functions Return ``True`` if *x* is a NaN (not a number), and ``False`` otherwise. +.. function:: isqrt(n) + + Return the integer square root of the nonnegative integer *n*. This is the + floor of the exact square root of *n*, or equivalently the greatest integer + *a* such that *a*\ ² |nbsp| ≤ |nbsp| *n*. + + For some applications, it may be more convenient to have the least integer + *a* such that *n* |nbsp| ≤ |nbsp| *a*\ ², or in other words the ceiling of + the exact square root of *n*. For positive *n*, this can be computed using + ``a = 1 + isqrt(n - 1)``. + + .. versionadded:: 3.8 + + .. function:: ldexp(x, i) Return ``x * (2**i)``. This is essentially the inverse of function @@ -538,3 +552,6 @@ Constants Module :mod:`cmath` Complex number versions of many of these functions. + +.. |nbsp| unicode:: 0xA0 + :trim: diff --git a/Doc/whatsnew/3.8.rst b/Doc/whatsnew/3.8.rst index d47993bf112..07da4047a38 100644 --- a/Doc/whatsnew/3.8.rst +++ b/Doc/whatsnew/3.8.rst @@ -344,6 +344,9 @@ Added new function, :func:`math.prod`, as analogous function to :func:`sum` that returns the product of a 'start' value (default: 1) times an iterable of numbers. (Contributed by Pablo Galindo in :issue:`35606`) +Added new function :func:`math.isqrt` for computing integer square roots. +(Contributed by Mark Dickinson in :issue:`36887`.) + os -- diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py index cb05dee0e0f..a11a3447856 100644 --- a/Lib/test/test_math.py +++ b/Lib/test/test_math.py @@ -912,6 +912,57 @@ class MathTests(unittest.TestCase): self.assertEqual(math.dist(p, q), 5*scale) self.assertEqual(math.dist(q, p), 5*scale) + def testIsqrt(self): + # Test a variety of inputs, large and small. + test_values = ( + list(range(1000)) + + list(range(10**6 - 1000, 10**6 + 1000)) + + [3**9999, 10**5001] + ) + + for value in test_values: + with self.subTest(value=value): + s = math.isqrt(value) + self.assertIs(type(s), int) + self.assertLessEqual(s*s, value) + self.assertLess(value, (s+1)*(s+1)) + + # Negative values + with self.assertRaises(ValueError): + math.isqrt(-1) + + # Integer-like things + s = math.isqrt(True) + self.assertIs(type(s), int) + self.assertEqual(s, 1) + + s = math.isqrt(False) + self.assertIs(type(s), int) + self.assertEqual(s, 0) + + class IntegerLike(object): + def __init__(self, value): + self.value = value + + def __index__(self): + return self.value + + s = math.isqrt(IntegerLike(1729)) + self.assertIs(type(s), int) + self.assertEqual(s, 41) + + with self.assertRaises(ValueError): + math.isqrt(IntegerLike(-3)) + + # Non-integer-like things + bad_values = [ + 3.5, "a string", decimal.Decimal("3.5"), 3.5j, + 100.0, -4.0, + ] + for value in bad_values: + with self.subTest(value=value): + with self.assertRaises(TypeError): + math.isqrt(value) def testLdexp(self): self.assertRaises(TypeError, math.ldexp) diff --git a/Misc/NEWS.d/next/Library/2019-05-11-14-50-59.bpo-36887.XD3f22.rst b/Misc/NEWS.d/next/Library/2019-05-11-14-50-59.bpo-36887.XD3f22.rst new file mode 100644 index 00000000000..fe2929cea85 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2019-05-11-14-50-59.bpo-36887.XD3f22.rst @@ -0,0 +1 @@ +Add new function :func:`math.isqrt` to compute integer square roots. diff --git a/Modules/clinic/mathmodule.c.h b/Modules/clinic/mathmodule.c.h index 1806a01588c..e677bd896cd 100644 --- a/Modules/clinic/mathmodule.c.h +++ b/Modules/clinic/mathmodule.c.h @@ -65,6 +65,15 @@ PyDoc_STRVAR(math_fsum__doc__, #define MATH_FSUM_METHODDEF \ {"fsum", (PyCFunction)math_fsum, METH_O, math_fsum__doc__}, +PyDoc_STRVAR(math_isqrt__doc__, +"isqrt($module, n, /)\n" +"--\n" +"\n" +"Return the integer part of the square root of the input."); + +#define MATH_ISQRT_METHODDEF \ + {"isqrt", (PyCFunction)math_isqrt, METH_O, math_isqrt__doc__}, + PyDoc_STRVAR(math_factorial__doc__, "factorial($module, x, /)\n" "--\n" @@ -628,4 +637,4 @@ skip_optional_kwonly: exit: return return_value; } -/*[clinic end generated code: output=96e71135dce41c48 input=a9049054013a1b77]*/ +/*[clinic end generated code: output=aeed62f403b90199 input=a9049054013a1b77]*/ diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c index 8f6a303cc4d..821309221f8 100644 --- a/Modules/mathmodule.c +++ b/Modules/mathmodule.c @@ -1476,6 +1476,266 @@ count_set_bits(unsigned long n) return count; } +/* Integer square root + +Given a nonnegative integer `n`, we want to compute the largest integer +`a` for which `a * a <= n`, or equivalently the integer part of the exact +square root of `n`. + +We use an adaptive-precision pure-integer version of Newton's iteration. Given +a positive integer `n`, the algorithm produces at each iteration an integer +approximation `a` to the square root of `n >> s` for some even integer `s`, +with `s` decreasing as the iterations progress. On the final iteration, `s` is +zero and we have an approximation to the square root of `n` itself. + +At every step, the approximation `a` is strictly within 1.0 of the true square +root, so we have + + (a - 1)**2 < (n >> s) < (a + 1)**2 + +After the final iteration, a check-and-correct step is needed to determine +whether `a` or `a - 1` gives the desired integer square root of `n`. + +The algorithm is remarkable in its simplicity. There's no need for a +per-iteration check-and-correct step, and termination is straightforward: the +number of iterations is known in advance (it's exactly `floor(log2(log2(n)))` +for `n > 1`). The only tricky part of the correctness proof is in establishing +that the bound `(a - 1)**2 < (n >> s) < (a + 1)**2` is maintained from one +iteration to the next. A sketch of the proof of this is given below. + +In addition to the proof sketch, a formal, computer-verified proof +of correctness (using Lean) of an equivalent recursive algorithm can be found +here: + + https://github.com/mdickinson/snippets/blob/master/proofs/isqrt/src/isqrt.lean + + +Here's Python code equivalent to the C implementation below: + + def isqrt(n): + """ + Return the integer part of the square root of the input. + """ + n = operator.index(n) + + if n < 0: + raise ValueError("isqrt() argument must be nonnegative") + if n == 0: + return 0 + + c = (n.bit_length() - 1) // 2 + a = 1 + d = 0 + for s in reversed(range(c.bit_length())): + e = d + d = c >> s + a = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a + assert (a-1)**2 < n >> 2*(c - d) < (a+1)**2 + + return a - (a*a > n) + + +Sketch of proof of correctness +------------------------------ + +The delicate part of the correctness proof is showing that the loop invariant +is preserved from one iteration to the next. That is, just before the line + + a = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a + +is executed in the above code, we know that + + (1) (a - 1)**2 < (n >> 2*(c - e)) < (a + 1)**2. + +(since `e` is always the value of `d` from the previous iteration). We must +prove that after that line is executed, we have + + (a - 1)**2 < (n >> 2*(c - d)) < (a + 1)**2 + +To faciliate the proof, we make some changes of notation. Write `m` for +`n >> 2*(c-d)`, and write `b` for the new value of `a`, so + + b = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a + +or equivalently: + + (2) b = (a << d - e - 1) + (m >> d - e + 1) // a + +Then we can rewrite (1) as: + + (3) (a - 1)**2 < (m >> 2*(d - e)) < (a + 1)**2 + +and we must show that (b - 1)**2 < m < (b + 1)**2. + +From this point on, we switch to mathematical notation, so `/` means exact +division rather than integer division and `^` is used for exponentiation. We +use the `√` symbol for the exact square root. In (3), we can remove the +implicit floor operation to give: + + (4) (a - 1)^2 < m / 4^(d - e) < (a + 1)^2 + +Taking square roots throughout (4), scaling by `2^(d-e)`, and rearranging gives + + (5) 0 <= | 2^(d-e)a - √m | < 2^(d-e) + +Squaring and dividing through by `2^(d-e+1) a` gives + + (6) 0 <= 2^(d-e-1) a + m / (2^(d-e+1) a) - √m < 2^(d-e-1) / a + +We'll show below that `2^(d-e-1) <= a`. Given that, we can replace the +right-hand side of (6) with `1`, and now replacing the central +term `m / (2^(d-e+1) a)` with its floor in (6) gives + + (7) -1 < 2^(d-e-1) a + m // 2^(d-e+1) a - √m < 1 + +Or equivalently, from (2): + + (7) -1 < b - √m < 1 + +and rearranging gives that `(b-1)^2 < m < (b+1)^2`, which is what we needed +to prove. + +We're not quite done: we still have to prove the inequality `2^(d - e - 1) <= +a` that was used to get line (7) above. From the definition of `c`, we have +`4^c <= n`, which implies + + (8) 4^d <= m + +also, since `e == d >> 1`, `d` is at most `2e + 1`, from which it follows +that `2d - 2e - 1 <= d` and hence that + + (9) 4^(2d - 2e - 1) <= m + +Dividing both sides by `4^(d - e)` gives + + (10) 4^(d - e - 1) <= m / 4^(d - e) + +But we know from (4) that `m / 4^(d-e) < (a + 1)^2`, hence + + (11) 4^(d - e - 1) < (a + 1)^2 + +Now taking square roots of both sides and observing that both `2^(d-e-1)` and +`a` are integers gives `2^(d - e - 1) <= a`, which is what we needed. This +completes the proof sketch. + +*/ + +/*[clinic input] +math.isqrt + + n: object + / + +Return the integer part of the square root of the input. +[clinic start generated code]*/ + +static PyObject * +math_isqrt(PyObject *module, PyObject *n) +/*[clinic end generated code: output=35a6f7f980beab26 input=5b6e7ae4fa6c43d6]*/ +{ + int a_too_large, s; + size_t c, d; + PyObject *a = NULL, *b; + + n = PyNumber_Index(n); + if (n == NULL) { + return NULL; + } + + if (_PyLong_Sign(n) < 0) { + PyErr_SetString( + PyExc_ValueError, + "isqrt() argument must be nonnegative"); + goto error; + } + if (_PyLong_Sign(n) == 0) { + Py_DECREF(n); + return PyLong_FromLong(0); + } + + c = _PyLong_NumBits(n); + if (c == (size_t)(-1)) { + goto error; + } + c = (c - 1U) / 2U; + + /* s = c.bit_length() */ + s = 0; + while ((c >> s) > 0) { + ++s; + } + + a = PyLong_FromLong(1); + if (a == NULL) { + goto error; + } + d = 0; + while (--s >= 0) { + PyObject *q, *shift; + size_t e = d; + + d = c >> s; + + /* q = (n >> 2*c - e - d + 1) // a */ + shift = PyLong_FromSize_t(2U*c - d - e + 1U); + if (shift == NULL) { + goto error; + } + q = PyNumber_Rshift(n, shift); + Py_DECREF(shift); + if (q == NULL) { + goto error; + } + Py_SETREF(q, PyNumber_FloorDivide(q, a)); + if (q == NULL) { + goto error; + } + + /* a = (a << d - 1 - e) + q */ + shift = PyLong_FromSize_t(d - 1U - e); + if (shift == NULL) { + Py_DECREF(q); + goto error; + } + Py_SETREF(a, PyNumber_Lshift(a, shift)); + Py_DECREF(shift); + if (a == NULL) { + Py_DECREF(q); + goto error; + } + Py_SETREF(a, PyNumber_Add(a, q)); + Py_DECREF(q); + if (a == NULL) { + goto error; + } + } + + /* The correct result is either a or a - 1. Figure out which, and + decrement a if necessary. */ + + /* a_too_large = n < a * a */ + b = PyNumber_Multiply(a, a); + if (b == NULL) { + goto error; + } + a_too_large = PyObject_RichCompareBool(n, b, Py_LT); + Py_DECREF(b); + if (a_too_large == -1) { + goto error; + } + + if (a_too_large) { + Py_SETREF(a, PyNumber_Subtract(a, _PyLong_One)); + } + Py_DECREF(n); + return a; + + error: + Py_XDECREF(a); + Py_DECREF(n); + return NULL; +} + /* Divide-and-conquer factorial algorithm * * Based on the formula and pseudo-code provided at: @@ -2737,6 +2997,7 @@ static PyMethodDef math_methods[] = { MATH_ISFINITE_METHODDEF MATH_ISINF_METHODDEF MATH_ISNAN_METHODDEF + MATH_ISQRT_METHODDEF MATH_LDEXP_METHODDEF {"lgamma", math_lgamma, METH_O, math_lgamma_doc}, MATH_LOG_METHODDEF