gh-118610: Centralize power caching in `_pylong.py` (#118611)

A new `compute_powers()` function computes all and only the powers of the base the various base-conversion functions need, as efficiently as reasonably possible (turns out that invoking `**` is needed at most once). This typically gives a few % speedup, but the primary point is to simplify the base-conversion functions, which no longer need their own, ad hoc, and less efficient power-caching schemes.

Co-authored-by: Serhiy Storchaka <storchaka@gmail.com>
This commit is contained in:
Tim Peters 2024-05-07 19:09:09 -05:00 committed by GitHub
parent 2a85bed89d
commit 2f0a338be6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 113 additions and 67 deletions

View File

@ -19,6 +19,86 @@ try:
except ImportError:
_decimal = None
# A number of functions have this form, where `w` is a desired number of
# digits in base `base`:
#
# def inner(...w...):
# if w <= LIMIT:
# return something
# lo = w >> 1
# hi = w - lo
# something involving base**lo, inner(...lo...), j, and inner(...hi...)
# figure out largest w needed
# result = inner(w)
#
# They all had some on-the-fly scheme to cache `base**lo` results for reuse.
# Power is costly.
#
# This routine aims to compute all and only the needed powers in advance, as
# efficiently as reasonably possible. This isn't trivial, and all the
# on-the-fly methods did needless work in many cases. The driving code above
# changes to:
#
# figure out largest w needed
# mycache = compute_powers(w, base, LIMIT)
# result = inner(w)
#
# and `mycache[lo]` replaces `base**lo` in the inner function.
#
# While this does give minor speedups (a few percent at best), the primary
# intent is to simplify the functions using this, by eliminating the need for
# them to craft their own ad-hoc caching schemes.
def compute_powers(w, base, more_than, show=False):
    """Return a dict mapping each needed exponent `e` to `base**e`.

    The needed exponents are found by simulating a divide-and-conquer
    recursion that splits a width `w` into `lo = w >> 1` and `hi = w - lo`
    and stops at any width that is `<= more_than`.  Every `lo` met along
    the way becomes a key of the result; `hi` values are explored but are
    only included if some other split needs them as its `lo`.

    Arguments:
        w: the largest width the caller will split.
        base: the value to raise to powers; it only has to support `**`
            with an int exponent and `*` with itself.
        more_than: widths `<= more_than` are not split further.
        show: if true, print a trace of how each power is obtained.

    Returns an empty dict when `w <= more_than`.  Otherwise `**` is used
    exactly once, for the smallest needed exponent; each larger power comes
    from a smaller one by multiplying by `base`, squaring, or squaring and
    then multiplying by `base`.
    """
    # Phase 1: collect the exponents the recursion would ask for.
    seen = set()
    need = set()
    ws = {w}
    while ws:
        w = ws.pop()  # any element is fine to use next
        if w in seen or w <= more_than:
            continue
        seen.add(w)
        lo = w >> 1
        # only _need_ lo here; some other path may, or may not, need hi
        need.add(lo)
        ws.add(lo)
        if w & 1:
            # Odd width: the upper half is one larger than the lower half.
            ws.add(lo + 1)

    d = {}
    if not need:
        return d

    # Phase 2: compute the powers in increasing order of exponent, so each
    # one can be built from an entry that is already in `d`.
    it = iter(sorted(need))
    first = next(it)
    if show:
        print("pow at", first)
    d[first] = base ** first
    for this in it:
        if this - 1 in d:
            if show:
                print("* base at", this)
            d[this] = d[this - 1] * base  # cheap
        else:
            lo = this >> 1
            hi = this - lo
            # `this` is needed, so some width split down to it, and that
            # same walk also recorded `this >> 1` as needed.
            assert lo in d
            if show:
                print("square at", this)
            # Multiplying a bigint by itself (same object!) is about twice
            # as fast in CPython.
            sq = d[lo] * d[lo]
            if hi != lo:
                assert hi == lo + 1
                if show:
                    print("    and * base")
                sq *= base
            d[this] = sq
    return d
# Shared decimal context for the conversions in this file.  It starts as a
# copy of the context that is current when this code runs, then has its
# precision and exponent range raised to the largest values the decimal
# module allows.  int_to_decimal() installs it via decimal.localcontext().
_unbounded_dec_context = decimal.getcontext().copy()
_unbounded_dec_context.prec = decimal.MAX_PREC
_unbounded_dec_context.Emax = decimal.MAX_EMAX
_unbounded_dec_context.Emin = decimal.MIN_EMIN
# Sanity check: with the Inexact trap enabled, any operation under this
# context that would have to round raises instead of silently losing digits.
_unbounded_dec_context.traps[decimal.Inexact] = 1
def int_to_decimal(n):
"""Asymptotically fast conversion of an 'int' to Decimal."""
@ -33,57 +113,32 @@ def int_to_decimal(n):
# "clever" recursive way. If we want a string representation, we
# apply str to _that_.
D = decimal.Decimal
D2 = D(2)
BITLIM = 128
mem = {}
def w2pow(w):
"""Return D(2)**w and store the result. Also possibly save some
intermediate results. In context, these are likely to be reused
across various levels of the conversion to Decimal."""
if (result := mem.get(w)) is None:
if w <= BITLIM:
result = D2**w
elif w - 1 in mem:
result = (t := mem[w - 1]) + t
else:
w2 = w >> 1
# If w happens to be odd, w-w2 is one larger then w2
# now. Recurse on the smaller first (w2), so that it's
# in the cache and the larger (w-w2) can be handled by
# the cheaper `w-1 in mem` branch instead.
result = w2pow(w2) * w2pow(w - w2)
mem[w] = result
return result
from decimal import Decimal as D
BITLIM = 200
# Don't bother caching the "lo" mask in this; the time to compute it is
# tiny compared to the multiply.
def inner(n, w):
if w <= BITLIM:
return D(n)
w2 = w >> 1
hi = n >> w2
lo = n - (hi << w2)
return inner(lo, w2) + inner(hi, w - w2) * w2pow(w2)
with decimal.localcontext() as ctx:
ctx.prec = decimal.MAX_PREC
ctx.Emax = decimal.MAX_EMAX
ctx.Emin = decimal.MIN_EMIN
ctx.traps[decimal.Inexact] = 1
lo = n & ((1 << w2) - 1)
return inner(lo, w2) + inner(hi, w - w2) * w2pow[w2]
with decimal.localcontext(_unbounded_dec_context):
nbits = n.bit_length()
w2pow = compute_powers(nbits, D(2), BITLIM)
if n < 0:
negate = True
n = -n
else:
negate = False
result = inner(n, n.bit_length())
result = inner(n, nbits)
if negate:
result = -result
return result
def int_to_decimal_string(n):
"""Asymptotically fast conversion of an 'int' to a decimal string."""
w = n.bit_length()
@ -97,14 +152,13 @@ def int_to_decimal_string(n):
# available. This algorithm is asymptotically worse than the algorithm
# using the decimal module, but better than the quadratic time
# implementation in longobject.c.
DIGLIM = 1000
def inner(n, w):
if w <= 1000:
if w <= DIGLIM:
return str(n)
w2 = w >> 1
d = pow10_cache.get(w2)
if d is None:
d = pow10_cache[w2] = 5**w2 << w2 # 10**i = (5*2)**i = 5**i * 2**i
hi, lo = divmod(n, d)
hi, lo = divmod(n, pow10[w2])
return inner(hi, w - w2) + inner(lo, w2).zfill(w2)
# The estimation of the number of decimal digits.
@ -115,7 +169,9 @@ def int_to_decimal_string(n):
# only if the number has way more than 10**15 digits, that exceeds
# the 52-bit physical address limit in both Intel64 and AMD64.
w = int(w * 0.3010299956639812 + 1) # log10(2)
pow10_cache = {}
pow10 = compute_powers(w, 5, DIGLIM)
for k, v in pow10.items():
pow10[k] = v << k # 5**k << k == 5**k * 2**k == 10**k
if n < 0:
n = -n
sign = '-'
@ -128,7 +184,6 @@ def int_to_decimal_string(n):
s = s.lstrip('0')
return sign + s
def _str_to_int_inner(s):
"""Asymptotically fast conversion of a 'str' to an 'int'."""
@ -144,35 +199,15 @@ def _str_to_int_inner(s):
DIGLIM = 2048
mem = {}
def w5pow(w):
"""Return 5**w and store the result.
Also possibly save some intermediate results. In context, these
are likely to be reused across various levels of the conversion
to 'int'.
"""
if (result := mem.get(w)) is None:
if w <= DIGLIM:
result = 5**w
elif w - 1 in mem:
result = mem[w - 1] * 5
else:
w2 = w >> 1
# If w happens to be odd, w-w2 is one larger then w2
# now. Recurse on the smaller first (w2), so that it's
# in the cache and the larger (w-w2) can be handled by
# the cheaper `w-1 in mem` branch instead.
result = w5pow(w2) * w5pow(w - w2)
mem[w] = result
return result
def inner(a, b):
if b - a <= DIGLIM:
return int(s[a:b])
mid = (a + b + 1) >> 1
return inner(mid, b) + ((inner(a, mid) * w5pow(b - mid)) << (b - mid))
return (inner(mid, b)
+ ((inner(a, mid) * w5pow[b - mid])
<< (b - mid)))
w5pow = compute_powers(len(s), 5, DIGLIM)
return inner(0, len(s))
@ -186,7 +221,6 @@ def int_from_string(s):
s = s.rstrip().replace('_', '')
return _str_to_int_inner(s)
def str_to_int(s):
"""Asymptotically fast version of decimal string to 'int' conversion."""
# FIXME: this doesn't support the full syntax that int() supports.

View File

@ -906,6 +906,18 @@ class PyLongModuleTests(unittest.TestCase):
with self.assertRaises(RuntimeError):
int(big_value)
def test_pylong_roundtrip(self):
    """Check that int -> str -> int round-trips across many large sizes.

    Bit lengths start near 5000 and roughly double on every iteration until
    they exceed 1_000_000.  Each length is first jittered by a random offset
    in [-100, 100] so the sizes do not follow an exact doubling pattern.
    """
    from random import randrange, getrandbits
    bits = 5000
    while bits <= 1_000_000:
        bits += randrange(-100, 101)  # break bitlength patterns
        # Force the top bit on so `n` has exactly `bits` bits.
        hibit = 1 << (bits - 1)
        n = hibit | getrandbits(bits - 1)
        # Sanity check on the test's own construction.  Unlike a bare
        # `assert`, this still runs under -O and reports both values.
        self.assertEqual(n.bit_length(), bits)
        sn = str(n)
        self.assertFalse(sn.startswith('0'))
        self.assertEqual(n, int(sn))
        bits <<= 1
if __name__ == "__main__":
unittest.main()