gh-118750: Asymptotically faster `int(string)` (#118751)

Asymptotically faster (O(n log n)) str->int for very large strings, leveraging the faster multiplication scheme in the C-coded `_decimal` when available. This is used instead of the current Karatsuba-limited method starting at 2 million digits.

Lots of opportunity remains for fine-tuning. Good targets include changing BYTELIM, and possibly changing the internal output base (from 256 to a higher number of bytes).

Doing this was substantial work, and many of the new lines are actually comments giving correctness proofs. The obvious approaches sticking to integers were too slow to be useful, so this is doing variable-precision decimal floating-point arithmetic. Much faster, but worst-possible rounding errors have to be wholly accounted for, using as little precision as possible.

Special thanks to Serhiy Storchaka for asking many good questions in his code reviews!

Co-authored-by: Jelle Zijlstra <jelle.zijlstra@gmail.com>
Co-authored-by: sstandre <43125375+sstandre@users.noreply.github.com>
Co-authored-by: Pieter Eendebak <pieter.eendebak@gmail.com>
Co-authored-by: Nice Zombies <nineteendo19d0@gmail.com>
This commit is contained in:
Tim Peters 2024-05-18 19:19:57 -05:00 committed by GitHub
parent caf6064a1b
commit ecd8664f11
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 479 additions and 33 deletions

View File

@ -45,10 +45,16 @@ except ImportError:
#
# and `mycache[lo]` replaces `base**lo` in the inner function.
#
# While this does give minor speedups (a few percent at best), the primary
# intent is to simplify the functions using this, by eliminating the need for
# them to craft their own ad-hoc caching schemes.
def compute_powers(w, base, more_than, show=False):
# If an algorithm wants the powers of ceiling(w/2) instead of the floor,
# pass keyword argument `need_hi=True`.
#
# While this does give minor speedups (a few percent at best), the
# primary intent is to simplify the functions using this, by eliminating
# the need for them to craft their own ad-hoc caching schemes.
#
# See code near end of file for a block of code that can be enabled to
# run millions of tests.
def compute_powers(w, base, more_than, *, need_hi=False, show=False):
seen = set()
need = set()
ws = {w}
@ -58,40 +64,70 @@ def compute_powers(w, base, more_than, show=False):
continue
seen.add(w)
lo = w >> 1
# only _need_ lo here; some other path may, or may not, need hi
need.add(lo)
ws.add(lo)
if w & 1:
ws.add(lo + 1)
hi = w - lo
# only _need_ one here; the other may, or may not, be needed
which = hi if need_hi else lo
need.add(which)
ws.add(which)
if lo != hi:
ws.add(w - which)
# `need` is the set of exponents needed. To compute them all
# efficiently, possibly add other exponents to `extra`. The goal is
# to ensure that each exponent can be gotten from a smaller one via
# multiplying by the base, squaring it, or squaring and then
# multiplying by the base.
#
# If need_hi is False, this is already the case (w can always be
# gotten from w >> 1 via one of the squaring strategies). But we do
# the work anyway, just in case ;-)
#
# Note that speed is irrelevant. These loops are working on little
# ints (exponents) and go around O(log w) times. The total cost is
# insignificant compared to just one of the bigint multiplies.
cands = need.copy()
extra = set()
while cands:
w = max(cands)
cands.remove(w)
lo = w >> 1
if lo > more_than and w-1 not in cands and lo not in cands:
extra.add(lo)
cands.add(lo)
assert need_hi or not extra
d = {}
if not need:
return d
it = iter(sorted(need))
first = next(it)
if show:
print("pow at", first)
d[first] = base ** first
for this in it:
if this - 1 in d:
for n in sorted(need | extra):
lo = n >> 1
hi = n - lo
if n-1 in d:
if show:
print("* base at", this)
d[this] = d[this - 1] * base # cheap
else:
lo = this >> 1
hi = this - lo
assert lo in d
print("* base", end="")
result = d[n-1] * base # cheap!
elif lo in d:
# Multiplying a bigint by itself is about twice as fast
# in CPython provided it's the same object.
if show:
print("square at", this)
# Multiplying a bigint by itself (same object!) is about twice
# as fast in CPython.
sq = d[lo] * d[lo]
print("square", end="")
result = d[lo] * d[lo] # same object
if hi != lo:
assert hi == lo + 1
if show:
print(" and * base")
sq *= base
d[this] = sq
print(" * base", end="")
assert 2 * lo + 1 == n
result *= base
else: # rare
if show:
print("pow", end='')
result = base ** n
if show:
print(" at", n, "needed" if n in need else "extra")
d[n] = result
assert need <= d.keys()
if excess := d.keys() - need:
assert need_hi
for n in excess:
del d[n]
return d
_unbounded_dec_context = decimal.getcontext().copy()
@ -211,6 +247,145 @@ def _str_to_int_inner(s):
return inner(0, len(s))
# Asymptotically faster version, using the C decimal module. See
# comments at the end of the file. This uses decimal arithmetic to
# convert from base 10 to base 256. The latter is just a string of
# bytes, which CPython can convert very efficiently to a Python int.

# log of 10 to base 256 with best-possible 53-bit precision. Obtained
# via:
#    from mpmath import mp
#    mp.prec = 1000
#    print(float(mp.log(10, 256)).hex())
_LOG_10_BASE_256 = float.fromhex('0x1.a934f0979a371p-2') # about 0.415

# _spread is for internal testing. It maps a key to the number of times
# that condition obtained in _dec_str_to_int_inner:
#     key 0   - quotient guess was right
#     key 1   - quotient had to be boosted by 1, one time
#     key 999 - one adjustment wasn't enough, so fell back to divmod
# `defaultdict` is imported only to build `_spread`, and the name is
# deleted again right away so it does not linger in the module namespace.
from collections import defaultdict
_spread = defaultdict(int)
del defaultdict
def _dec_str_to_int_inner(s, *, GUARD=8):
    """Convert the decimal digit string `s` to an int via `decimal`.

    The value is built as a Decimal, split recursively into base-256
    "digits" (bytes) using multiplication by truncated reciprocals of
    powers of 256, and the collected bytes are turned into an int with
    int.from_bytes().

    GUARD is the number of extra decimal digits of working precision
    used when estimating quotients; it must be > 0.  How often each
    quotient estimate needed correcting is tallied in the module-level
    `_spread` mapping.

    Raises ValueError if the byte width computed from the length of `s`
    is too large for the float-based width estimate to be trusted.
    """
    # Yes, BYTELIM is "large". Large enough that CPython will usually
    # use the Karatsuba _str_to_int_inner to convert the string. This
    # allowed reducing the cutoff for calling _this_ function from 3.5M
    # to 2M digits. We could almost certainly do even better by
    # fine-tuning this and/or using a larger output base than 256.
    BYTELIM = 100_000
    D = decimal.Decimal
    result = bytearray()
    # See notes at end of file for discussion of GUARD.
    assert GUARD > 0 # if 0, `decimal` can blow up - .prec 0 not allowed

    def inner(n, w):
        # Append the `w`-byte big-endian representation of the
        # non-negative integer-valued Decimal `n` to `result`.
        #assert n < D256 ** w # required, but too expensive to check
        if w <= BYTELIM:
            # XXX Stefan Pochmann discovered that, for 1024-bit ints,
            # `int(Decimal)` took 2.5x longer than `int(str(Decimal))`.
            # Worse, `int(Decimal)` is still quadratic-time for much
            # larger ints. So unless/until all that is repaired, the
            # seemingly redundant `str(Decimal)` is crucial to speed.
            result.extend(int(str(n)).to_bytes(w)) # big-endian default
            return
        w1 = w >> 1
        w2 = w - w1
        if 0:
            # This is maximally clear, but "too slow". `decimal`
            # division is asymptotically fast, but we have no way to
            # tell it to reuse the high-precision reciprocal it computes
            # for pow256[w2], so it has to recompute it over & over &
            # over again :-(
            hi, lo = divmod(n, pow256[w2][0])
        else:
            p256, recip = pow256[w2]
            # The integer part will have a number of digits about equal
            # to the difference between the log10s of `n` and `pow256`
            # (which, since these are integers, is roughly approximated
            # by `.adjusted()`). That's the working precision we need:
            ctx.prec = max(n.adjusted() - p256.adjusted(), 0) + GUARD
            hi = +n * +recip # unary `+` chops back to ctx.prec digits
            ctx.prec = decimal.MAX_PREC
            hi = hi.to_integral_value() # lose the fractional digits
            lo = n - hi * p256
            # Because we've been uniformly rounding down, `hi` is a
            # lower bound on the correct quotient.
            assert lo >= 0
            # Adjust quotient up if needed. It usually isn't. In random
            # testing on inputs through 5 billion digit strings, the
            # test triggered once in about 200 thousand tries.
            count = 0
            if lo >= p256:
                count = 1
                lo -= p256
                hi += 1
                if lo >= p256:
                    # Complete correction via an exact computation. I
                    # believe it's not possible to get here provided
                    # GUARD >= 3. It's tested by reducing GUARD below
                    # that.
                    count = 999
                    hi2, lo = divmod(lo, p256)
                    hi += hi2
            _spread[count] += 1
            # The assert should always succeed, but way too slow to keep
            # enabled. (Note the parentheses: without them the statement
            # would assert a bare `hi` with a message, not the pair.)
            #assert (hi, lo) == divmod(n, pow256[w2][0])
        inner(hi, w1)
        del hi # at top levels, can free a lot of RAM "early"
        inner(lo, w2)

    # How many base 256 digits are needed? Mathematically, exactly
    # floor(log256(int(s))) + 1. There is no cheap way to compute this.
    # But we can get an upper bound, and that's necessary for our error
    # analysis to make sense. int(s) < 10**len(s), so the log needed is
    # < log256(10**len(s)) = len(s) * log256(10). However, using
    # finite-precision floating point for this, it's possible that the
    # computed value is a little less than the true value. If the true
    # value is at - or a little higher than - an integer, we can get an
    # off-by-1 error too low. So we add 2 instead of 1 if chopping lost
    # a fraction > 0.9.

    # The "WASI" test platform can complain about `len(s)` if it's too
    # large to fit in its idea of "an index-sized integer".
    lenS = s.__len__()
    log_ub = lenS * _LOG_10_BASE_256
    log_ub_as_int = int(log_ub)
    w = log_ub_as_int + 1 + (log_ub - log_ub_as_int > 0.9)
    # And what if we've plain exhausted the limits of HW floats? We
    # could compute the log to any desired precision using `decimal`,
    # but it's not plausible that anyone will pass a string requiring
    # trillions of bytes (unless they're just trying to "break things").
    if w.bit_length() >= 46:
        # "Only" had < 53 - 46 = 7 bits to spare in IEEE-754 double.
        raise ValueError(f"cannot convert string of len {lenS} to int")
    with decimal.localcontext(_unbounded_dec_context) as ctx:
        D256 = D(256)
        pow256 = compute_powers(w, D256, BYTELIM, need_hi=True)
        rpow256 = compute_powers(w, 1 / D256, BYTELIM, need_hi=True)
        # We're going to do inexact, chopped arithmetic, multiplying by
        # an approximation to the reciprocal of 256**i. We chop to get a
        # lower bound on the true integer quotient. Our approximation is
        # a lower bound, the multiplication is chopped too, and
        # to_integral_value() is also chopped.
        ctx.traps[decimal.Inexact] = 0
        ctx.rounding = decimal.ROUND_DOWN
        # Only values of existing keys are replaced below, so mutating
        # `pow256` while iterating over its items is safe.
        for k, v in pow256.items():
            # No need to save much more precision in the reciprocal than
            # the power of 256 has, plus some guard digits to absorb
            # most relevant rounding errors. This is highly significant:
            # 1/2**i has the same number of significant decimal digits
            # as 5**i, generally over twice the number in 2**i.
            ctx.prec = v.adjusted() + GUARD + 1
            # The unary "+" chops the reciprocal back to that precision.
            pow256[k] = v, +rpow256[k]
        del rpow256 # exact reciprocals no longer needed
        ctx.prec = decimal.MAX_PREC
        inner(D(s), w)
    return int.from_bytes(result)
def int_from_string(s):
"""Asymptotically fast version of PyLong_FromString(), conversion
of a string of decimal digits into an 'int'."""
@ -219,7 +394,10 @@ def int_from_string(s):
# and underscores, and stripped leading whitespace. The input can still
# contain underscores and have trailing whitespace.
s = s.rstrip().replace('_', '')
return _str_to_int_inner(s)
func = _str_to_int_inner
if len(s) >= 2_000_000 and _decimal is not None:
func = _dec_str_to_int_inner
return func(s)
def str_to_int(s):
"""Asymptotically fast version of decimal string to 'int' conversion."""
@ -361,3 +539,191 @@ def int_divmod(a, b):
return ~q, b + ~r
else:
return _divmod_pos(a, b)
# Notes on _dec_str_to_int_inner:
#
# Stefan Pochmann worked up a str->int function that used the decimal
# module to, in effect, convert from base 10 to base 256. This is
# "unnatural", in that it requires multiplying and dividing by large
# powers of 2, which `decimal` isn't naturally suited to. But
# `decimal`'s `*` and `/` are asymptotically superior to CPython's, so
# at _some_ point it could be expected to win.
#
# Alas, the crossover point was too high to be of much real interest. I
# (Tim) then worked on ways to replace its division with multiplication
# by a cached reciprocal approximation instead, fixing up errors
# afterwards. This reduced the crossover point significantly.
#
# I revisited the code, and found ways to improve and simplify it. The
# crossover point is at about 3.4 million digits now.
#
# About .adjusted()
# -----------------
# Restrict to Decimal values x > 0. We don't use negative numbers in the
# code, and I don't want to have to keep typing, e.g., "absolute value".
#
# For convenience, I'll use `x.a` to mean `x.adjusted()`. x.a doesn't
# look at the digits of x, but instead returns an integer giving x's
# order of magnitude. These are equivalent:
#
# - x.a is the power-of-10 exponent of x's most significant digit.
# - x.a = the infinitely precise floor(log10(x))
# - x can be written in this form, where f is a real with 1 <= f < 10:
# x = f * 10**x.a
#
# Observation: if x is an integer, len(str(x)) = x.a + 1.
#
# Lemma 1: (x * y).a = x.a + y.a, or one larger
#
# Proof: Write x = f * 10**x.a and y = g * 10**y.a, where f and g are in
# [1, 10). Then x*y = f*g * 10**(x.a + y.a), where 1 <= f*g < 100. If
# f*g < 10, (x*y).a is x.a+y.a. Else divide f*g by 10 to bring it back
# into [1, 10), and add 1 to the exponent to compensate. Then (x*y).a is
# x.a+y.a+1.
#
# Lemma 2: ceiling(log10(x/y)) <= x.a - y.a + 1
#
# Proof: Express x and y as in Lemma 1. Then x/y = f/g * 10**(x.a -
# y.a), where 1/10 < f/g < 10. If 1 <= f/g, (x/y).a is x.a-y.a. Else
# multiply f/g by 10 to bring it back into [1, 10), and subtract 1 from
# the exponent to compensate. Then (x/y).a is x.a-y.a-1. So the largest
# (x/y).a can be is x.a-y.a. Since that's the floor of log10(x/y), the
# ceiling is at most 1 larger (with equality iff f/g = 1 exactly).
#
# GUARD digits
# ------------
# We only want the integer part of divisions, so don't need to build
# the full multiplication tree. But using _just_ the number of
# digits expected in the integer part ignores too much. What's left
# out can have a very significant effect on the quotient. So we use
# GUARD additional digits.
#
# The default 8 is more than enough so no more than 1 correction step
# was ever needed for all inputs tried through 2.5 billion digits. In
# fact, I believe 3 guard digits are always enough - but the proof is
# very involved, so better safe than sorry.
#
# Short course:
#
# If prec is the decimal precision in effect, and we're rounding down,
# the result of an operation is exactly equal to the infinitely precise
# result times 1-e for some real e with 0 <= e < 10**(1-prec). In
#
# ctx.prec = max(n.adjusted() - p256.adjusted(), 0) + GUARD
# hi = +n * +recip # unary `+` chops to ctx.prec digits
#
# we have 3 visible chopped operations, but there's also a 4th:
# precomputing a truncated `recip` as part of setup.
#
# So the computed product is exactly equal to the true product times
# (1-e1)*(1-e2)*(1-e3)*(1-e4); since the e's are all very small, an
# excellent approximation to the second factor is 1-(e1+e2+e3+e4) (the
# 2nd and higher order terms in the expanded product are too tiny to
# matter). If they're all as large as possible, that's
#
# 1 - 4*10**(1-prec). This, BTW, is all bog-standard FP error analysis.
#
# That implies the computed product is within 1 of the true product
# provided prec >= log10(true_product) + 1.602.
#
# Here are telegraphic details, rephrasing the initial condition in
# equivalent ways, step by step:
#
# prod - prod * (1 - 4*10**(1-prec)) <= 1
# prod - prod + prod * 4*10**(1-prec) <= 1
# prod * 4*10**(1-prec) <= 1
# 10**(log10(prod)) * 4*10**(1-prec) <= 1
# 4*10**(1-prec+log10(prod)) <= 1
# 10**(1-prec+log10(prod)) <= 1/4
# 1-prec+log10(prod) <= log10(1/4) = -0.602
# -prec <= -1.602 - log10(prod)
# prec >= log10(prod) + 1.602
#
# The true product is the same as the true ratio n/p256. By Lemma 2
# above, n.a - p256.a + 1 is an upper bound on the ceiling of
# log10(prod). Then 2 is the ceiling of 1.602, so n.a - p256.a + 3 is an
# upper bound on the right hand side of the inequality. Any prec >= that
# will work.
#
# But since this is just a sketch of a proof ;-), the code uses the
# empirically tested 8 instead of 3. 5 digits more or less makes no
# practical difference to speed - these ints are huge. And while
# increasing GUARD above 3 may not be necessary, every increase cuts the
# percentage of cases that need a correction at all.
#
# On Computing Reciprocals
# ------------------------
# In general, the exact reciprocals we compute have over twice as many
# significant digits as needed. 1/256**i has the same number of
# significant decimal digits as 5**i. It's a significant waste of RAM
# to store all those unneeded digits.
#
# So we cut exact reciprocals back to the least precision that can
# be needed so that the error analysis above is valid.
#
# [Note: turns out it's very significantly faster to do it this way than
# to compute 1 / 256**i directly to the desired precision, because the
# power method doesn't require division. It's also faster than computing
# (1/256)**i directly to the desired precision - no material division
# there, but `compute_powers()` is much smarter about _how_ to compute
# all the powers needed than repeated applications of `**` - that
# function invokes `**` for at most the few smallest powers needed.]
#
# The hard part is that chopping back to a shorter width occurs
# _outside_ of `inner`. We can't know then what `prec` `inner()` will
# need. We have to pick, for each value of `w2`, the largest possible
# value `prec` can become when `inner()` is working on `w2`.
#
# This is the `prec` inner() uses:
# max(n.a - p256.a, 0) + GUARD
# and what setup uses (renaming its `v` to `p256` - same thing):
# p256.a + GUARD + 1
#
# We need that the second is always at least as large as the first,
# which is the same as requiring
#
# n.a - 2 * p256.a <= 1
#
# What's the largest n can be? n < 256**w = 256**(w2 + (w - w2)). The
# worst case in this context is when w is even, and then w = 2*w2, so
# n < 256**(2*w2) = (256**w2)**2 = p256**2. By Lemma 1, then, n.a
# is at most p256.a + p256.a + 1.
#
# So the most n.a - 2 * p256.a can be is
# p256.a + p256.a + 1 - 2 * p256.a = 1. QED
#
# Note: an earlier version of the code split on floor(w/2) instead of on
# the ceiling. The worst case then is odd `w`, and a more involved proof
# was needed to show that adding 4 (instead of 1) may be necessary.
# Basically because, in that case, n may be up to 256 times larger than
# p256**2. Curiously enough, by splitting on the ceiling instead,
# nothing in any proof here actually depends on the output base (256).
# Enable for brute-force testing of compute_powers(). This takes about a
# minute, because it tries millions of cases.
if 0:
    def consumer(w, limit, need_hi):
        # Independently work out which exponents a divide-and-conquer
        # consumer splitting `w` down to `limit` would look up, then check
        # that compute_powers() returns exactly that set of keys.
        #
        # `limit` is an explicit parameter so the nested `inner()` closes
        # over the value passed in, instead of silently depending on the
        # module-level loop variable of the same name.
        seen = set()
        need = set()
        def inner(w):
            if w <= limit:
                return
            if w in seen:
                return
            seen.add(w)
            lo = w >> 1
            hi = w - lo
            # Only one half's power is needed at each split.
            need.add(hi if need_hi else lo)
            inner(lo)
            inner(hi)
        inner(w)
        # Base 1 keeps every "power" trivially cheap; only keys matter.
        exp = compute_powers(w, 1, limit, need_hi=need_hi)
        assert exp.keys() == need

    from itertools import chain
    for need_hi in (False, True):
        for limit in (0, 1, 10, 100, 1_000, 10_000, 100_000):
            for w in chain(range(1, 100_000),
                           (10**i for i in range(5, 30))):
                consumer(w, limit, need_hi)

View File

@ -919,5 +919,84 @@ class PyLongModuleTests(unittest.TestCase):
self.assertEqual(n, int(sn))
bits <<= 1
@support.requires_resource('cpu')
def test_pylong_roundtrip_huge(self):
    """Round-trip a 10-million-digit value through int(str) and str(int)."""
    block = "1234567890"
    repeats = 1_000_000  # 10 digits per block, so 10 million digits in all
    text = block * repeats
    # (B**k - 1) // (B - 1) with B = 10**10 is 1 followed by k-1 groups of
    # "0000000001"; multiplying by the block value repeats it k times.
    radix = 10 ** len(block)
    repunit = (radix**repeats - 1) // (radix - 1)
    value = int(block) * repunit
    self.assertEqual(value, int(text))
    self.assertEqual(text, str(value))
@support.requires_resource('cpu')
@unittest.skipUnless(_pylong, "_pylong module required")
def test_whitebox_dec_str_to_int_inner_failsafe(self):
    """Drive _dec_str_to_int_inner into its exact-divmod fallback path."""
    # The default number of GUARD digits is believed to always suffice
    # for a single correction step, so no ordinary input reaches the
    # "failsafe" block. With GUARD cut down to 1, this contrived string
    # of nines does reach it, which `_spread` records under key 999.
    digits = "9" * 2000156
    expected = 10**len(digits) - 1
    saved_spread = _pylong._spread.copy()
    _pylong._spread.clear()
    try:
        actual = _pylong._dec_str_to_int_inner(digits, GUARD=1)
        self.assertEqual(expected, actual)
        self.assertIn(999, _pylong._spread)
    finally:
        # Put the module's statistics back exactly as they were.
        _pylong._spread.clear()
        _pylong._spread.update(saved_spread)
@unittest.skipUnless(_pylong, "_pylong module required")
def test_whitebox_dec_str_to_int_inner_monster(self):
    """Check the ValueError raised for an implausibly long input string."""
    # I don't think anyone has enough RAM to build a string long enough
    # for this function to complain. So lie about the string length.

    class LyingStr(str):
        def __len__(self):
            return int((1 << 47) / _pylong._LOG_10_BASE_256)

    liar = LyingStr("42")
    # We have to pass the liar directly to the complaining function. If we
    # just try `int(liar)`, earlier layers will replace it with plain old
    # "42".
    # Embedding `len(liar)` into the f-string failed on the WASI testbot
    # (don't know what that is):
    #     OverflowError: cannot fit 'int' into an index-sized integer
    # So a random stab at worming around that.
    self.assertRaisesRegex(ValueError,
        f"^cannot convert string of len {liar.__len__()} to int$",
        _pylong._dec_str_to_int_inner,
        liar)
@unittest.skipUnless(_pylong, "_pylong module required")
def test_pylong_compute_powers(self):
    """Basic sanity checks of _pylong.compute_powers().

    See the end of _pylong.py for the manual heavy tests.
    """
    def expected_exponents(width, limit, need_hi):
        # Exponents a consumer would look up while halving `width` until
        # the pieces are no larger than `limit`.
        visited = set()
        wanted = set()
        pending = [width]
        while pending:
            cur = pending.pop()
            if cur <= limit or cur in visited:
                continue
            visited.add(cur)
            low = cur >> 1
            high = cur - low
            wanted.add(high if need_hi else low)
            pending.extend((low, high))
        return wanted

    def check(width, base, limit, need_hi):
        wanted = expected_exponents(width, limit, need_hi)
        powers = _pylong.compute_powers(width, base, limit,
                                        need_hi=need_hi)
        self.assertEqual(powers.keys(), wanted)
        for exponent, value in powers.items():
            self.assertEqual(value, base ** exponent)

    for base in (2, 5):
        for need_hi in (False, True):
            for limit in (1, 11):
                for width in range(250, 550):
                    check(width, base, limit, need_hi)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1 @@
If the C version of the ``decimal`` module is available, ``int(str)`` now uses it to supply an asymptotically much faster conversion. However, this only applies if the string contains over about 2 million digits.