bpo-46233: Minor speedup for bigint squaring (GH-30345)

x_mul()'s squaring code can do some redundant and/or useless
work at the end of each digit pass. A more careful analysis
of worst-case carries at various digit positions allows
making that code leaner.
This commit is contained in:
Tim Peters 2022-01-03 20:41:16 -06:00 committed by GitHub
parent f1a58441ee
commit 3aa5242b54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 5 deletions

View File

@ -1502,6 +1502,17 @@ class LongTest(unittest.TestCase):
self.assertEqual(type(numerator), int) self.assertEqual(type(numerator), int)
self.assertEqual(type(denominator), int) self.assertEqual(type(denominator), int)
def test_square(self):
# Multiplication makes a special case of multiplying an int with
# itself, using a special, faster algorithm. This test is mostly
# to ensure that no asserts in the implementation trigger, in
# cases with a maximal amount of carries.
for bitlen in range(1, 400):
n = (1 << bitlen) - 1 # solid string of 1 bits
with self.subTest(bitlen=bitlen, n=n):
# (2**i - 1)**2 = 2**(2*i) - 2*2**i + 1
self.assertEqual(n**2,
(1 << (2 * bitlen)) - (1 << (bitlen + 1)) + 1)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -3237,12 +3237,12 @@ x_mul(PyLongObject *a, PyLongObject *b)
* via exploiting that each entry in the multiplication * via exploiting that each entry in the multiplication
* pyramid appears twice (except for the size_a squares). * pyramid appears twice (except for the size_a squares).
*/ */
digit *paend = a->ob_digit + size_a;
for (i = 0; i < size_a; ++i) { for (i = 0; i < size_a; ++i) {
twodigits carry; twodigits carry;
twodigits f = a->ob_digit[i]; twodigits f = a->ob_digit[i];
digit *pz = z->ob_digit + (i << 1); digit *pz = z->ob_digit + (i << 1);
digit *pa = a->ob_digit + i + 1; digit *pa = a->ob_digit + i + 1;
digit *paend = a->ob_digit + size_a;
SIGCHECK({ SIGCHECK({
Py_DECREF(z); Py_DECREF(z);
@ -3265,13 +3265,27 @@ x_mul(PyLongObject *a, PyLongObject *b)
assert(carry <= (PyLong_MASK << 1)); assert(carry <= (PyLong_MASK << 1));
} }
if (carry) { if (carry) {
/* See comment below. pz points at the highest possible
* carry position from the last outer loop iteration, so
* *pz is at most 1.
*/
assert(*pz <= 1);
carry += *pz; carry += *pz;
*pz++ = (digit)(carry & PyLong_MASK); *pz = (digit)(carry & PyLong_MASK);
carry >>= PyLong_SHIFT; carry >>= PyLong_SHIFT;
if (carry) {
/* If there's still a carry, it must be into a position
* that still holds a 0. Where the base
^ B is 1 << PyLong_SHIFT, the last add was of a carry no
* more than 2*B - 2 to a stored digit no more than 1.
* So the sum was no more than 2*B - 1, so the current
* carry no more than floor((2*B - 1)/B) = 1.
*/
assert(carry == 1);
assert(pz[1] == 0);
pz[1] = (digit)carry;
}
} }
if (carry)
*pz += (digit)(carry & PyLong_MASK);
assert((carry >> PyLong_SHIFT) == 0);
} }
} }
else { /* a is not the same as b -- gradeschool int mult */ else { /* a is not the same as b -- gradeschool int mult */