bpo-36887: add math.isqrt (GH-13244)
* Add math.isqrt function computing the integer square root. * Code cleanup: remove redundant comments, rename some variables. * Tighten up code a bit more; use Py_XDECREF to simplify error handling. * Update Modules/mathmodule.c Co-Authored-By: Serhiy Storchaka <storchaka@gmail.com> * Update Modules/mathmodule.c Use real argument clinic type instead of an alias Co-Authored-By: Serhiy Storchaka <storchaka@gmail.com> * Add proof sketch * Updates from review. * Correct and expand documentation. * Fix bad reference handling on error; make some variables block-local; other tidying. * Style and consistency fixes. * Add missing error check; don't try to DECREF a NULL a * Simplify some error returns. * Another two test cases: - clarify that floats are rejected even if they happen to be squares of small integers - TypeError beats ValueError for a negative float * Documentation and markup improvements; thanks Serhiy for the suggestions! * Cleaner Misc/NEWS entry wording. * Clean up (with one fix) to the algorithm explanation and proof.
This commit is contained in:
parent
410759fba8
commit
73934b9da0
|
@ -166,6 +166,20 @@ Number-theoretic and representation functions
|
||||||
Return ``True`` if *x* is a NaN (not a number), and ``False`` otherwise.
|
Return ``True`` if *x* is a NaN (not a number), and ``False`` otherwise.
|
||||||
|
|
||||||
|
|
||||||
|
.. function:: isqrt(n)
|
||||||
|
|
||||||
|
Return the integer square root of the nonnegative integer *n*. This is the
|
||||||
|
floor of the exact square root of *n*, or equivalently the greatest integer
|
||||||
|
*a* such that *a*\ ² |nbsp| ≤ |nbsp| *n*.
|
||||||
|
|
||||||
|
For some applications, it may be more convenient to have the least integer
|
||||||
|
*a* such that *n* |nbsp| ≤ |nbsp| *a*\ ², or in other words the ceiling of
|
||||||
|
the exact square root of *n*. For positive *n*, this can be computed using
|
||||||
|
``a = 1 + isqrt(n - 1)``.
|
||||||
|
|
||||||
|
.. versionadded:: 3.8
|
||||||
|
|
||||||
|
|
||||||
.. function:: ldexp(x, i)
|
.. function:: ldexp(x, i)
|
||||||
|
|
||||||
Return ``x * (2**i)``. This is essentially the inverse of function
|
Return ``x * (2**i)``. This is essentially the inverse of function
|
||||||
|
@ -538,3 +552,6 @@ Constants
|
||||||
|
|
||||||
Module :mod:`cmath`
|
Module :mod:`cmath`
|
||||||
Complex number versions of many of these functions.
|
Complex number versions of many of these functions.
|
||||||
|
|
||||||
|
.. |nbsp| unicode:: 0xA0
|
||||||
|
:trim:
|
||||||
|
|
|
@ -344,6 +344,9 @@ Added new function, :func:`math.prod`, as analogous function to :func:`sum`
|
||||||
that returns the product of a 'start' value (default: 1) times an iterable of
|
that returns the product of a 'start' value (default: 1) times an iterable of
|
||||||
numbers. (Contributed by Pablo Galindo in :issue:`35606`)
|
numbers. (Contributed by Pablo Galindo in :issue:`35606`)
|
||||||
|
|
||||||
|
Added new function :func:`math.isqrt` for computing integer square roots.
|
||||||
|
(Contributed by Mark Dickinson in :issue:`36887`.)
|
||||||
|
|
||||||
os
|
os
|
||||||
--
|
--
|
||||||
|
|
||||||
|
|
|
@ -912,6 +912,57 @@ class MathTests(unittest.TestCase):
|
||||||
self.assertEqual(math.dist(p, q), 5*scale)
|
self.assertEqual(math.dist(p, q), 5*scale)
|
||||||
self.assertEqual(math.dist(q, p), 5*scale)
|
self.assertEqual(math.dist(q, p), 5*scale)
|
||||||
|
|
||||||
|
def testIsqrt(self):
|
||||||
|
# Test a variety of inputs, large and small.
|
||||||
|
test_values = (
|
||||||
|
list(range(1000))
|
||||||
|
+ list(range(10**6 - 1000, 10**6 + 1000))
|
||||||
|
+ [3**9999, 10**5001]
|
||||||
|
)
|
||||||
|
|
||||||
|
for value in test_values:
|
||||||
|
with self.subTest(value=value):
|
||||||
|
s = math.isqrt(value)
|
||||||
|
self.assertIs(type(s), int)
|
||||||
|
self.assertLessEqual(s*s, value)
|
||||||
|
self.assertLess(value, (s+1)*(s+1))
|
||||||
|
|
||||||
|
# Negative values
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
math.isqrt(-1)
|
||||||
|
|
||||||
|
# Integer-like things
|
||||||
|
s = math.isqrt(True)
|
||||||
|
self.assertIs(type(s), int)
|
||||||
|
self.assertEqual(s, 1)
|
||||||
|
|
||||||
|
s = math.isqrt(False)
|
||||||
|
self.assertIs(type(s), int)
|
||||||
|
self.assertEqual(s, 0)
|
||||||
|
|
||||||
|
class IntegerLike(object):
|
||||||
|
def __init__(self, value):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def __index__(self):
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
s = math.isqrt(IntegerLike(1729))
|
||||||
|
self.assertIs(type(s), int)
|
||||||
|
self.assertEqual(s, 41)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
math.isqrt(IntegerLike(-3))
|
||||||
|
|
||||||
|
# Non-integer-like things
|
||||||
|
bad_values = [
|
||||||
|
3.5, "a string", decimal.Decimal("3.5"), 3.5j,
|
||||||
|
100.0, -4.0,
|
||||||
|
]
|
||||||
|
for value in bad_values:
|
||||||
|
with self.subTest(value=value):
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
math.isqrt(value)
|
||||||
|
|
||||||
def testLdexp(self):
|
def testLdexp(self):
|
||||||
self.assertRaises(TypeError, math.ldexp)
|
self.assertRaises(TypeError, math.ldexp)
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
Add new function :func:`math.isqrt` to compute integer square roots.
|
|
@ -65,6 +65,15 @@ PyDoc_STRVAR(math_fsum__doc__,
|
||||||
#define MATH_FSUM_METHODDEF \
|
#define MATH_FSUM_METHODDEF \
|
||||||
{"fsum", (PyCFunction)math_fsum, METH_O, math_fsum__doc__},
|
{"fsum", (PyCFunction)math_fsum, METH_O, math_fsum__doc__},
|
||||||
|
|
||||||
|
PyDoc_STRVAR(math_isqrt__doc__,
|
||||||
|
"isqrt($module, n, /)\n"
|
||||||
|
"--\n"
|
||||||
|
"\n"
|
||||||
|
"Return the integer part of the square root of the input.");
|
||||||
|
|
||||||
|
#define MATH_ISQRT_METHODDEF \
|
||||||
|
{"isqrt", (PyCFunction)math_isqrt, METH_O, math_isqrt__doc__},
|
||||||
|
|
||||||
PyDoc_STRVAR(math_factorial__doc__,
|
PyDoc_STRVAR(math_factorial__doc__,
|
||||||
"factorial($module, x, /)\n"
|
"factorial($module, x, /)\n"
|
||||||
"--\n"
|
"--\n"
|
||||||
|
@ -628,4 +637,4 @@ skip_optional_kwonly:
|
||||||
exit:
|
exit:
|
||||||
return return_value;
|
return return_value;
|
||||||
}
|
}
|
||||||
/*[clinic end generated code: output=96e71135dce41c48 input=a9049054013a1b77]*/
|
/*[clinic end generated code: output=aeed62f403b90199 input=a9049054013a1b77]*/
|
||||||
|
|
|
@ -1476,6 +1476,266 @@ count_set_bits(unsigned long n)
|
||||||
return count;
|
return count;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Integer square root
|
||||||
|
|
||||||
|
Given a nonnegative integer `n`, we want to compute the largest integer
|
||||||
|
`a` for which `a * a <= n`, or equivalently the integer part of the exact
|
||||||
|
square root of `n`.
|
||||||
|
|
||||||
|
We use an adaptive-precision pure-integer version of Newton's iteration. Given
|
||||||
|
a positive integer `n`, the algorithm produces at each iteration an integer
|
||||||
|
approximation `a` to the square root of `n >> s` for some even integer `s`,
|
||||||
|
with `s` decreasing as the iterations progress. On the final iteration, `s` is
|
||||||
|
zero and we have an approximation to the square root of `n` itself.
|
||||||
|
|
||||||
|
At every step, the approximation `a` is strictly within 1.0 of the true square
|
||||||
|
root, so we have
|
||||||
|
|
||||||
|
(a - 1)**2 < (n >> s) < (a + 1)**2
|
||||||
|
|
||||||
|
After the final iteration, a check-and-correct step is needed to determine
|
||||||
|
whether `a` or `a - 1` gives the desired integer square root of `n`.
|
||||||
|
|
||||||
|
The algorithm is remarkable in its simplicity. There's no need for a
|
||||||
|
per-iteration check-and-correct step, and termination is straightforward: the
|
||||||
|
number of iterations is known in advance (it's exactly `floor(log2(log2(n)))`
|
||||||
|
for `n > 1`). The only tricky part of the correctness proof is in establishing
|
||||||
|
that the bound `(a - 1)**2 < (n >> s) < (a + 1)**2` is maintained from one
|
||||||
|
iteration to the next. A sketch of the proof of this is given below.
|
||||||
|
|
||||||
|
In addition to the proof sketch, a formal, computer-verified proof
|
||||||
|
of correctness (using Lean) of an equivalent recursive algorithm can be found
|
||||||
|
here:
|
||||||
|
|
||||||
|
https://github.com/mdickinson/snippets/blob/master/proofs/isqrt/src/isqrt.lean
|
||||||
|
|
||||||
|
|
||||||
|
Here's Python code equivalent to the C implementation below:
|
||||||
|
|
||||||
|
def isqrt(n):
|
||||||
|
"""
|
||||||
|
Return the integer part of the square root of the input.
|
||||||
|
"""
|
||||||
|
n = operator.index(n)
|
||||||
|
|
||||||
|
if n < 0:
|
||||||
|
raise ValueError("isqrt() argument must be nonnegative")
|
||||||
|
if n == 0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
c = (n.bit_length() - 1) // 2
|
||||||
|
a = 1
|
||||||
|
d = 0
|
||||||
|
for s in reversed(range(c.bit_length())):
|
||||||
|
e = d
|
||||||
|
d = c >> s
|
||||||
|
a = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a
|
||||||
|
assert (a-1)**2 < n >> 2*(c - d) < (a+1)**2
|
||||||
|
|
||||||
|
return a - (a*a > n)
|
||||||
|
|
||||||
|
|
||||||
|
Sketch of proof of correctness
|
||||||
|
------------------------------
|
||||||
|
|
||||||
|
The delicate part of the correctness proof is showing that the loop invariant
|
||||||
|
is preserved from one iteration to the next. That is, just before the line
|
||||||
|
|
||||||
|
a = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a
|
||||||
|
|
||||||
|
is executed in the above code, we know that
|
||||||
|
|
||||||
|
(1) (a - 1)**2 < (n >> 2*(c - e)) < (a + 1)**2.
|
||||||
|
|
||||||
|
(since `e` is always the value of `d` from the previous iteration). We must
|
||||||
|
prove that after that line is executed, we have
|
||||||
|
|
||||||
|
(a - 1)**2 < (n >> 2*(c - d)) < (a + 1)**2
|
||||||
|
|
||||||
|
To faciliate the proof, we make some changes of notation. Write `m` for
|
||||||
|
`n >> 2*(c-d)`, and write `b` for the new value of `a`, so
|
||||||
|
|
||||||
|
b = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a
|
||||||
|
|
||||||
|
or equivalently:
|
||||||
|
|
||||||
|
(2) b = (a << d - e - 1) + (m >> d - e + 1) // a
|
||||||
|
|
||||||
|
Then we can rewrite (1) as:
|
||||||
|
|
||||||
|
(3) (a - 1)**2 < (m >> 2*(d - e)) < (a + 1)**2
|
||||||
|
|
||||||
|
and we must show that (b - 1)**2 < m < (b + 1)**2.
|
||||||
|
|
||||||
|
From this point on, we switch to mathematical notation, so `/` means exact
|
||||||
|
division rather than integer division and `^` is used for exponentiation. We
|
||||||
|
use the `√` symbol for the exact square root. In (3), we can remove the
|
||||||
|
implicit floor operation to give:
|
||||||
|
|
||||||
|
(4) (a - 1)^2 < m / 4^(d - e) < (a + 1)^2
|
||||||
|
|
||||||
|
Taking square roots throughout (4), scaling by `2^(d-e)`, and rearranging gives
|
||||||
|
|
||||||
|
(5) 0 <= | 2^(d-e)a - √m | < 2^(d-e)
|
||||||
|
|
||||||
|
Squaring and dividing through by `2^(d-e+1) a` gives
|
||||||
|
|
||||||
|
(6) 0 <= 2^(d-e-1) a + m / (2^(d-e+1) a) - √m < 2^(d-e-1) / a
|
||||||
|
|
||||||
|
We'll show below that `2^(d-e-1) <= a`. Given that, we can replace the
|
||||||
|
right-hand side of (6) with `1`, and now replacing the central
|
||||||
|
term `m / (2^(d-e+1) a)` with its floor in (6) gives
|
||||||
|
|
||||||
|
(7) -1 < 2^(d-e-1) a + m // 2^(d-e+1) a - √m < 1
|
||||||
|
|
||||||
|
Or equivalently, from (2):
|
||||||
|
|
||||||
|
(7) -1 < b - √m < 1
|
||||||
|
|
||||||
|
and rearranging gives that `(b-1)^2 < m < (b+1)^2`, which is what we needed
|
||||||
|
to prove.
|
||||||
|
|
||||||
|
We're not quite done: we still have to prove the inequality `2^(d - e - 1) <=
|
||||||
|
a` that was used to get line (7) above. From the definition of `c`, we have
|
||||||
|
`4^c <= n`, which implies
|
||||||
|
|
||||||
|
(8) 4^d <= m
|
||||||
|
|
||||||
|
also, since `e == d >> 1`, `d` is at most `2e + 1`, from which it follows
|
||||||
|
that `2d - 2e - 1 <= d` and hence that
|
||||||
|
|
||||||
|
(9) 4^(2d - 2e - 1) <= m
|
||||||
|
|
||||||
|
Dividing both sides by `4^(d - e)` gives
|
||||||
|
|
||||||
|
(10) 4^(d - e - 1) <= m / 4^(d - e)
|
||||||
|
|
||||||
|
But we know from (4) that `m / 4^(d-e) < (a + 1)^2`, hence
|
||||||
|
|
||||||
|
(11) 4^(d - e - 1) < (a + 1)^2
|
||||||
|
|
||||||
|
Now taking square roots of both sides and observing that both `2^(d-e-1)` and
|
||||||
|
`a` are integers gives `2^(d - e - 1) <= a`, which is what we needed. This
|
||||||
|
completes the proof sketch.
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*[clinic input]
|
||||||
|
math.isqrt
|
||||||
|
|
||||||
|
n: object
|
||||||
|
/
|
||||||
|
|
||||||
|
Return the integer part of the square root of the input.
|
||||||
|
[clinic start generated code]*/
|
||||||
|
|
||||||
|
static PyObject *
|
||||||
|
math_isqrt(PyObject *module, PyObject *n)
|
||||||
|
/*[clinic end generated code: output=35a6f7f980beab26 input=5b6e7ae4fa6c43d6]*/
|
||||||
|
{
|
||||||
|
int a_too_large, s;
|
||||||
|
size_t c, d;
|
||||||
|
PyObject *a = NULL, *b;
|
||||||
|
|
||||||
|
n = PyNumber_Index(n);
|
||||||
|
if (n == NULL) {
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (_PyLong_Sign(n) < 0) {
|
||||||
|
PyErr_SetString(
|
||||||
|
PyExc_ValueError,
|
||||||
|
"isqrt() argument must be nonnegative");
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
if (_PyLong_Sign(n) == 0) {
|
||||||
|
Py_DECREF(n);
|
||||||
|
return PyLong_FromLong(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
c = _PyLong_NumBits(n);
|
||||||
|
if (c == (size_t)(-1)) {
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
c = (c - 1U) / 2U;
|
||||||
|
|
||||||
|
/* s = c.bit_length() */
|
||||||
|
s = 0;
|
||||||
|
while ((c >> s) > 0) {
|
||||||
|
++s;
|
||||||
|
}
|
||||||
|
|
||||||
|
a = PyLong_FromLong(1);
|
||||||
|
if (a == NULL) {
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
d = 0;
|
||||||
|
while (--s >= 0) {
|
||||||
|
PyObject *q, *shift;
|
||||||
|
size_t e = d;
|
||||||
|
|
||||||
|
d = c >> s;
|
||||||
|
|
||||||
|
/* q = (n >> 2*c - e - d + 1) // a */
|
||||||
|
shift = PyLong_FromSize_t(2U*c - d - e + 1U);
|
||||||
|
if (shift == NULL) {
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
q = PyNumber_Rshift(n, shift);
|
||||||
|
Py_DECREF(shift);
|
||||||
|
if (q == NULL) {
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
Py_SETREF(q, PyNumber_FloorDivide(q, a));
|
||||||
|
if (q == NULL) {
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* a = (a << d - 1 - e) + q */
|
||||||
|
shift = PyLong_FromSize_t(d - 1U - e);
|
||||||
|
if (shift == NULL) {
|
||||||
|
Py_DECREF(q);
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
Py_SETREF(a, PyNumber_Lshift(a, shift));
|
||||||
|
Py_DECREF(shift);
|
||||||
|
if (a == NULL) {
|
||||||
|
Py_DECREF(q);
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
Py_SETREF(a, PyNumber_Add(a, q));
|
||||||
|
Py_DECREF(q);
|
||||||
|
if (a == NULL) {
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* The correct result is either a or a - 1. Figure out which, and
|
||||||
|
decrement a if necessary. */
|
||||||
|
|
||||||
|
/* a_too_large = n < a * a */
|
||||||
|
b = PyNumber_Multiply(a, a);
|
||||||
|
if (b == NULL) {
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
a_too_large = PyObject_RichCompareBool(n, b, Py_LT);
|
||||||
|
Py_DECREF(b);
|
||||||
|
if (a_too_large == -1) {
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (a_too_large) {
|
||||||
|
Py_SETREF(a, PyNumber_Subtract(a, _PyLong_One));
|
||||||
|
}
|
||||||
|
Py_DECREF(n);
|
||||||
|
return a;
|
||||||
|
|
||||||
|
error:
|
||||||
|
Py_XDECREF(a);
|
||||||
|
Py_DECREF(n);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
/* Divide-and-conquer factorial algorithm
|
/* Divide-and-conquer factorial algorithm
|
||||||
*
|
*
|
||||||
* Based on the formula and pseudo-code provided at:
|
* Based on the formula and pseudo-code provided at:
|
||||||
|
@ -2737,6 +2997,7 @@ static PyMethodDef math_methods[] = {
|
||||||
MATH_ISFINITE_METHODDEF
|
MATH_ISFINITE_METHODDEF
|
||||||
MATH_ISINF_METHODDEF
|
MATH_ISINF_METHODDEF
|
||||||
MATH_ISNAN_METHODDEF
|
MATH_ISNAN_METHODDEF
|
||||||
|
MATH_ISQRT_METHODDEF
|
||||||
MATH_LDEXP_METHODDEF
|
MATH_LDEXP_METHODDEF
|
||||||
{"lgamma", math_lgamma, METH_O, math_lgamma_doc},
|
{"lgamma", math_lgamma, METH_O, math_lgamma_doc},
|
||||||
MATH_LOG_METHODDEF
|
MATH_LOG_METHODDEF
|
||||||
|
|
Loading…
Reference in New Issue