From 3aa5242b54b0627293d95cfb4a26b2f917f667be Mon Sep 17 00:00:00 2001 From: Tim Peters Date: Mon, 3 Jan 2022 20:41:16 -0600 Subject: [PATCH] bpo-46233: Minor speedup for bigint squaring (GH-30345) x_mul()'s squaring code can do some redundant and/or useless work at the end of each digit pass. A more careful analysis of worst-case carries at various digit positions allows making that code leaner. --- Lib/test/test_long.py | 11 +++++++++++ Objects/longobject.c | 24 +++++++++++++++++++----- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/Lib/test/test_long.py b/Lib/test/test_long.py index 3c8e9e22e17..f2a622b5868 100644 --- a/Lib/test/test_long.py +++ b/Lib/test/test_long.py @@ -1502,6 +1502,17 @@ class LongTest(unittest.TestCase): self.assertEqual(type(numerator), int) self.assertEqual(type(denominator), int) + def test_square(self): + # Multiplication makes a special case of multiplying an int with + # itself, using a special, faster algorithm. This test is mostly + # to ensure that no asserts in the implementation trigger, in + # cases with a maximal amount of carries. + for bitlen in range(1, 400): + n = (1 << bitlen) - 1 # solid string of 1 bits + with self.subTest(bitlen=bitlen, n=n): + # (2**i - 1)**2 = 2**(2*i) - 2*2**i + 1 + self.assertEqual(n**2, + (1 << (2 * bitlen)) - (1 << (bitlen + 1)) + 1) if __name__ == "__main__": unittest.main() diff --git a/Objects/longobject.c b/Objects/longobject.c index b5648fca7dc..2db8701a841 100644 --- a/Objects/longobject.c +++ b/Objects/longobject.c @@ -3237,12 +3237,12 @@ x_mul(PyLongObject *a, PyLongObject *b) * via exploiting that each entry in the multiplication * pyramid appears twice (except for the size_a squares). 
*/ + digit *paend = a->ob_digit + size_a; for (i = 0; i < size_a; ++i) { twodigits carry; twodigits f = a->ob_digit[i]; digit *pz = z->ob_digit + (i << 1); digit *pa = a->ob_digit + i + 1; - digit *paend = a->ob_digit + size_a; SIGCHECK({ Py_DECREF(z); @@ -3265,13 +3265,27 @@ x_mul(PyLongObject *a, PyLongObject *b) assert(carry <= (PyLong_MASK << 1)); } if (carry) { + /* See comment below. pz points at the highest possible + * carry position from the last outer loop iteration, so + * *pz is at most 1. + */ + assert(*pz <= 1); carry += *pz; - *pz++ = (digit)(carry & PyLong_MASK); + *pz = (digit)(carry & PyLong_MASK); carry >>= PyLong_SHIFT; + if (carry) { + /* If there's still a carry, it must be into a position + * that still holds a 0. Where the base + * B is 1 << PyLong_SHIFT, the last add was of a carry no + * more than 2*B - 2 to a stored digit no more than 1. + * So the sum was no more than 2*B - 1, so the current + * carry is no more than floor((2*B - 1)/B) = 1. + */ + assert(carry == 1); + assert(pz[1] == 0); + pz[1] = (digit)carry; + } } - if (carry) - *pz += (digit)(carry & PyLong_MASK); - assert((carry >> PyLong_SHIFT) == 0); } } else { /* a is not the same as b -- gradeschool int mult */