mirror of https://github.com/python/cpython
gh-118610: Centralize power caching in `_pylong.py` (#118611)
A new `compute_powers()` function computes all and only the powers of the base the various base-conversion functions need, as efficiently as reasonably possible (turns out that invoking `**`is needed at most once). This typically gives a few % speedup, but the primary point is to simplify the base-conversion functions, which no longer need their own, ad hoc, and less efficient power-caching schemes. Co-authored-by: Serhiy Storchaka <storchaka@gmail.com>
This commit is contained in:
parent
2a85bed89d
commit
2f0a338be6
168
Lib/_pylong.py
168
Lib/_pylong.py
|
@ -19,6 +19,86 @@ try:
|
|||
except ImportError:
|
||||
_decimal = None
|
||||
|
||||
# A number of functions have this form, where `w` is a desired number of
|
||||
# digits in base `base`:
|
||||
#
|
||||
# def inner(...w...):
|
||||
# if w <= LIMIT:
|
||||
# return something
|
||||
# lo = w >> 1
|
||||
# hi = w - lo
|
||||
# something involving base**lo, inner(...lo...), j, and inner(...hi...)
|
||||
# figure out largest w needed
|
||||
# result = inner(w)
|
||||
#
|
||||
# They all had some on-the-fly scheme to cache `base**lo` results for reuse.
|
||||
# Power is costly.
|
||||
#
|
||||
# This routine aims to compute all amd only the needed powers in advance, as
|
||||
# efficiently as reasonably possible. This isn't trivial, and all the
|
||||
# on-the-fly methods did needless work in many cases. The driving code above
|
||||
# changes to:
|
||||
#
|
||||
# figure out largest w needed
|
||||
# mycache = compute_powers(w, base, LIMIT)
|
||||
# result = inner(w)
|
||||
#
|
||||
# and `mycache[lo]` replaces `base**lo` in the inner function.
|
||||
#
|
||||
# While this does give minor speedups (a few percent at best), the primary
|
||||
# intent is to simplify the functions using this, by eliminating the need for
|
||||
# them to craft their own ad-hoc caching schemes.
|
||||
def compute_powers(w, base, more_than, show=False):
|
||||
seen = set()
|
||||
need = set()
|
||||
ws = {w}
|
||||
while ws:
|
||||
w = ws.pop() # any element is fine to use next
|
||||
if w in seen or w <= more_than:
|
||||
continue
|
||||
seen.add(w)
|
||||
lo = w >> 1
|
||||
# only _need_ lo here; some other path may, or may not, need hi
|
||||
need.add(lo)
|
||||
ws.add(lo)
|
||||
if w & 1:
|
||||
ws.add(lo + 1)
|
||||
|
||||
d = {}
|
||||
if not need:
|
||||
return d
|
||||
it = iter(sorted(need))
|
||||
first = next(it)
|
||||
if show:
|
||||
print("pow at", first)
|
||||
d[first] = base ** first
|
||||
for this in it:
|
||||
if this - 1 in d:
|
||||
if show:
|
||||
print("* base at", this)
|
||||
d[this] = d[this - 1] * base # cheap
|
||||
else:
|
||||
lo = this >> 1
|
||||
hi = this - lo
|
||||
assert lo in d
|
||||
if show:
|
||||
print("square at", this)
|
||||
# Multiplying a bigint by itself (same object!) is about twice
|
||||
# as fast in CPython.
|
||||
sq = d[lo] * d[lo]
|
||||
if hi != lo:
|
||||
assert hi == lo + 1
|
||||
if show:
|
||||
print(" and * base")
|
||||
sq *= base
|
||||
d[this] = sq
|
||||
return d
|
||||
|
||||
_unbounded_dec_context = decimal.getcontext().copy()
|
||||
_unbounded_dec_context.prec = decimal.MAX_PREC
|
||||
_unbounded_dec_context.Emax = decimal.MAX_EMAX
|
||||
_unbounded_dec_context.Emin = decimal.MIN_EMIN
|
||||
_unbounded_dec_context.traps[decimal.Inexact] = 1 # sanity check
|
||||
|
||||
def int_to_decimal(n):
|
||||
"""Asymptotically fast conversion of an 'int' to Decimal."""
|
||||
|
@ -33,57 +113,32 @@ def int_to_decimal(n):
|
|||
# "clever" recursive way. If we want a string representation, we
|
||||
# apply str to _that_.
|
||||
|
||||
D = decimal.Decimal
|
||||
D2 = D(2)
|
||||
|
||||
BITLIM = 128
|
||||
|
||||
mem = {}
|
||||
|
||||
def w2pow(w):
|
||||
"""Return D(2)**w and store the result. Also possibly save some
|
||||
intermediate results. In context, these are likely to be reused
|
||||
across various levels of the conversion to Decimal."""
|
||||
if (result := mem.get(w)) is None:
|
||||
if w <= BITLIM:
|
||||
result = D2**w
|
||||
elif w - 1 in mem:
|
||||
result = (t := mem[w - 1]) + t
|
||||
else:
|
||||
w2 = w >> 1
|
||||
# If w happens to be odd, w-w2 is one larger then w2
|
||||
# now. Recurse on the smaller first (w2), so that it's
|
||||
# in the cache and the larger (w-w2) can be handled by
|
||||
# the cheaper `w-1 in mem` branch instead.
|
||||
result = w2pow(w2) * w2pow(w - w2)
|
||||
mem[w] = result
|
||||
return result
|
||||
from decimal import Decimal as D
|
||||
BITLIM = 200
|
||||
|
||||
# Don't bother caching the "lo" mask in this; the time to compute it is
|
||||
# tiny compared to the multiply.
|
||||
def inner(n, w):
|
||||
if w <= BITLIM:
|
||||
return D(n)
|
||||
w2 = w >> 1
|
||||
hi = n >> w2
|
||||
lo = n - (hi << w2)
|
||||
return inner(lo, w2) + inner(hi, w - w2) * w2pow(w2)
|
||||
|
||||
with decimal.localcontext() as ctx:
|
||||
ctx.prec = decimal.MAX_PREC
|
||||
ctx.Emax = decimal.MAX_EMAX
|
||||
ctx.Emin = decimal.MIN_EMIN
|
||||
ctx.traps[decimal.Inexact] = 1
|
||||
lo = n & ((1 << w2) - 1)
|
||||
return inner(lo, w2) + inner(hi, w - w2) * w2pow[w2]
|
||||
|
||||
with decimal.localcontext(_unbounded_dec_context):
|
||||
nbits = n.bit_length()
|
||||
w2pow = compute_powers(nbits, D(2), BITLIM)
|
||||
if n < 0:
|
||||
negate = True
|
||||
n = -n
|
||||
else:
|
||||
negate = False
|
||||
result = inner(n, n.bit_length())
|
||||
result = inner(n, nbits)
|
||||
if negate:
|
||||
result = -result
|
||||
return result
|
||||
|
||||
|
||||
def int_to_decimal_string(n):
|
||||
"""Asymptotically fast conversion of an 'int' to a decimal string."""
|
||||
w = n.bit_length()
|
||||
|
@ -97,14 +152,13 @@ def int_to_decimal_string(n):
|
|||
# available. This algorithm is asymptotically worse than the algorithm
|
||||
# using the decimal module, but better than the quadratic time
|
||||
# implementation in longobject.c.
|
||||
|
||||
DIGLIM = 1000
|
||||
def inner(n, w):
|
||||
if w <= 1000:
|
||||
if w <= DIGLIM:
|
||||
return str(n)
|
||||
w2 = w >> 1
|
||||
d = pow10_cache.get(w2)
|
||||
if d is None:
|
||||
d = pow10_cache[w2] = 5**w2 << w2 # 10**i = (5*2)**i = 5**i * 2**i
|
||||
hi, lo = divmod(n, d)
|
||||
hi, lo = divmod(n, pow10[w2])
|
||||
return inner(hi, w - w2) + inner(lo, w2).zfill(w2)
|
||||
|
||||
# The estimation of the number of decimal digits.
|
||||
|
@ -115,7 +169,9 @@ def int_to_decimal_string(n):
|
|||
# only if the number has way more than 10**15 digits, that exceeds
|
||||
# the 52-bit physical address limit in both Intel64 and AMD64.
|
||||
w = int(w * 0.3010299956639812 + 1) # log10(2)
|
||||
pow10_cache = {}
|
||||
pow10 = compute_powers(w, 5, DIGLIM)
|
||||
for k, v in pow10.items():
|
||||
pow10[k] = v << k # 5**k << k == 5**k * 2**k == 10**k
|
||||
if n < 0:
|
||||
n = -n
|
||||
sign = '-'
|
||||
|
@ -128,7 +184,6 @@ def int_to_decimal_string(n):
|
|||
s = s.lstrip('0')
|
||||
return sign + s
|
||||
|
||||
|
||||
def _str_to_int_inner(s):
|
||||
"""Asymptotically fast conversion of a 'str' to an 'int'."""
|
||||
|
||||
|
@ -144,35 +199,15 @@ def _str_to_int_inner(s):
|
|||
|
||||
DIGLIM = 2048
|
||||
|
||||
mem = {}
|
||||
|
||||
def w5pow(w):
|
||||
"""Return 5**w and store the result.
|
||||
Also possibly save some intermediate results. In context, these
|
||||
are likely to be reused across various levels of the conversion
|
||||
to 'int'.
|
||||
"""
|
||||
if (result := mem.get(w)) is None:
|
||||
if w <= DIGLIM:
|
||||
result = 5**w
|
||||
elif w - 1 in mem:
|
||||
result = mem[w - 1] * 5
|
||||
else:
|
||||
w2 = w >> 1
|
||||
# If w happens to be odd, w-w2 is one larger then w2
|
||||
# now. Recurse on the smaller first (w2), so that it's
|
||||
# in the cache and the larger (w-w2) can be handled by
|
||||
# the cheaper `w-1 in mem` branch instead.
|
||||
result = w5pow(w2) * w5pow(w - w2)
|
||||
mem[w] = result
|
||||
return result
|
||||
|
||||
def inner(a, b):
|
||||
if b - a <= DIGLIM:
|
||||
return int(s[a:b])
|
||||
mid = (a + b + 1) >> 1
|
||||
return inner(mid, b) + ((inner(a, mid) * w5pow(b - mid)) << (b - mid))
|
||||
return (inner(mid, b)
|
||||
+ ((inner(a, mid) * w5pow[b - mid])
|
||||
<< (b - mid)))
|
||||
|
||||
w5pow = compute_powers(len(s), 5, DIGLIM)
|
||||
return inner(0, len(s))
|
||||
|
||||
|
||||
|
@ -186,7 +221,6 @@ def int_from_string(s):
|
|||
s = s.rstrip().replace('_', '')
|
||||
return _str_to_int_inner(s)
|
||||
|
||||
|
||||
def str_to_int(s):
|
||||
"""Asymptotically fast version of decimal string to 'int' conversion."""
|
||||
# FIXME: this doesn't support the full syntax that int() supports.
|
||||
|
|
|
@ -906,6 +906,18 @@ class PyLongModuleTests(unittest.TestCase):
|
|||
with self.assertRaises(RuntimeError):
|
||||
int(big_value)
|
||||
|
||||
def test_pylong_roundtrip(self):
|
||||
from random import randrange, getrandbits
|
||||
bits = 5000
|
||||
while bits <= 1_000_000:
|
||||
bits += randrange(-100, 101) # break bitlength patterns
|
||||
hibit = 1 << (bits - 1)
|
||||
n = hibit | getrandbits(bits - 1)
|
||||
assert n.bit_length() == bits
|
||||
sn = str(n)
|
||||
self.assertFalse(sn.startswith('0'))
|
||||
self.assertEqual(n, int(sn))
|
||||
bits <<= 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue