gh-90716: add _pylong.py module (#96673)

Add Python implementations of certain longobject.c functions. These use
asymptotically faster algorithms that can be used for operations on
integers with many digits. In those cases, the performance overhead of
the Python implementation is not significant since the asymptotic
behavior is what dominates runtime. Functions provided by this module
should be considered private and not part of any public API.

Co-author: Tim Peters <tim.peters@gmail.com>
Co-author: Mark Dickinson <dickinsm@gmail.com>
Co-author: Bjorn Martinsson
This commit is contained in:
Neil Schemenauer 2022-10-25 22:00:50 -07:00 committed by GitHub
parent 5d30544485
commit de6981680b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 572 additions and 0 deletions

295
Lib/_pylong.py Normal file
View File

@ -0,0 +1,295 @@
"""Python implementations of some algorithms for use by longobject.c.
The goal is to provide asymptotically faster algorithms that can be
used for operations on integers with many digits. In those cases, the
performance overhead of the Python implementation is not significant
since the asymptotic behavior is what dominates runtime. Functions
provided by this module should be considered private and not part of any
public API.
Note: for ease of maintainability, please prefer clear code and avoid
"micro-optimizations". This module will only be imported and used for
integers with a huge number of digits. Saving a few microseconds with
tricky or non-obvious code is not worth it. For people looking for
maximum performance, they should use something like gmpy2."""
import sys
import re
import decimal
_DEBUG = False
def int_to_decimal(n):
"""Asymptotically fast conversion of an 'int' to Decimal."""
# Function due to Tim Peters. See GH issue #90716 for details.
# https://github.com/python/cpython/issues/90716
#
# The implementation in longobject.c of base conversion algorithms
# between power-of-2 and non-power-of-2 bases are quadratic time.
# This function implements a divide-and-conquer algorithm that is
# faster for large numbers. Builds an equal decimal.Decimal in a
# "clever" recursive way. If we want a string representation, we
# apply str to _that_.
if _DEBUG:
print('int_to_decimal', n.bit_length(), file=sys.stderr)
D = decimal.Decimal
D2 = D(2)
BITLIM = 128
mem = {}
def w2pow(w):
"""Return D(2)**w and store the result. Also possibly save some
intermediate results. In context, these are likely to be reused
across various levels of the conversion to Decimal."""
if (result := mem.get(w)) is None:
if w <= BITLIM:
result = D2**w
elif w - 1 in mem:
result = (t := mem[w - 1]) + t
else:
w2 = w >> 1
# If w happens to be odd, w-w2 is one larger then w2
# now. Recurse on the smaller first (w2), so that it's
# in the cache and the larger (w-w2) can be handled by
# the cheaper `w-1 in mem` branch instead.
result = w2pow(w2) * w2pow(w - w2)
mem[w] = result
return result
def inner(n, w):
if w <= BITLIM:
return D(n)
w2 = w >> 1
hi = n >> w2
lo = n - (hi << w2)
return inner(lo, w2) + inner(hi, w - w2) * w2pow(w2)
with decimal.localcontext() as ctx:
ctx.prec = decimal.MAX_PREC
ctx.Emax = decimal.MAX_EMAX
ctx.Emin = decimal.MIN_EMIN
ctx.traps[decimal.Inexact] = 1
if n < 0:
negate = True
n = -n
else:
negate = False
result = inner(n, n.bit_length())
if negate:
result = -result
return result
def int_to_decimal_string(n):
"""Asymptotically fast conversion of an 'int' to a decimal string."""
return str(int_to_decimal(n))
def _str_to_int_inner(s):
"""Asymptotically fast conversion of a 'str' to an 'int'."""
# Function due to Bjorn Martinsson. See GH issue #90716 for details.
# https://github.com/python/cpython/issues/90716
#
# The implementation in longobject.c of base conversion algorithms
# between power-of-2 and non-power-of-2 bases are quadratic time.
# This function implements a divide-and-conquer algorithm making use
# of Python's built in big int multiplication. Since Python uses the
# Karatsuba algorithm for multiplication, the time complexity
# of this function is O(len(s)**1.58).
DIGLIM = 2048
mem = {}
def w5pow(w):
"""Return 5**w and store the result.
Also possibly save some intermediate results. In context, these
are likely to be reused across various levels of the conversion
to 'int'.
"""
if (result := mem.get(w)) is None:
if w <= DIGLIM:
result = 5**w
elif w - 1 in mem:
result = mem[w - 1] * 5
else:
w2 = w >> 1
# If w happens to be odd, w-w2 is one larger then w2
# now. Recurse on the smaller first (w2), so that it's
# in the cache and the larger (w-w2) can be handled by
# the cheaper `w-1 in mem` branch instead.
result = w5pow(w2) * w5pow(w - w2)
mem[w] = result
return result
def inner(a, b):
if b - a <= DIGLIM:
return int(s[a:b])
mid = (a + b + 1) >> 1
return inner(mid, b) + ((inner(a, mid) * w5pow(b - mid)) << (b - mid))
return inner(0, len(s))
def int_from_string(s):
"""Asymptotically fast version of PyLong_FromString(), conversion
of a string of decimal digits into an 'int'."""
if _DEBUG:
print('int_from_string', len(s), file=sys.stderr)
# PyLong_FromString() has already removed leading +/-, checked for invalid
# use of underscore characters, checked that string consists of only digits
# and underscores, and stripped leading whitespace. The input can still
# contain underscores and have trailing whitespace.
s = s.rstrip().replace('_', '')
return _str_to_int_inner(s)
def str_to_int(s):
"""Asymptotically fast version of decimal string to 'int' conversion."""
# FIXME: this doesn't support the full syntax that int() supports.
m = re.match(r'\s*([+-]?)([0-9_]+)\s*', s)
if not m:
raise ValueError('invalid literal for int() with base 10')
v = int_from_string(m.group(2))
if m.group(1) == '-':
v = -v
return v
# Fast integer division, based on code from Mark Dickinson, fast_div.py
# GH-47701. Additional refinements and optimizations by Bjorn Martinsson. The
# algorithm is due to Burnikel and Ziegler, in their paper "Fast Recursive
# Division".
_DIV_LIMIT = 4000
def _div2n1n(a, b, n):
"""Divide a 2n-bit nonnegative integer a by an n-bit positive integer
b, using a recursive divide-and-conquer algorithm.
Inputs:
n is a positive integer
b is a positive integer with exactly n bits
a is a nonnegative integer such that a < 2**n * b
Output:
(q, r) such that a = b*q+r and 0 <= r < b.
"""
if a.bit_length() - n <= _DIV_LIMIT:
return divmod(a, b)
pad = n & 1
if pad:
a <<= 1
b <<= 1
n += 1
half_n = n >> 1
mask = (1 << half_n) - 1
b1, b2 = b >> half_n, b & mask
q1, r = _div3n2n(a >> n, (a >> half_n) & mask, b, b1, b2, half_n)
q2, r = _div3n2n(r, a & mask, b, b1, b2, half_n)
if pad:
r >>= 1
return q1 << half_n | q2, r
def _div3n2n(a12, a3, b, b1, b2, n):
"""Helper function for _div2n1n; not intended to be called directly."""
if a12 >> n == b1:
q, r = (1 << n) - 1, a12 - (b1 << n) + b1
else:
q, r = _div2n1n(a12, b1, n)
r = (r << n | a3) - q * b2
while r < 0:
q -= 1
r += b
return q, r
def _int2digits(a, n):
"""Decompose non-negative int a into base 2**n
Input:
a is a non-negative integer
Output:
List of the digits of a in base 2**n in little-endian order,
meaning the most significant digit is last. The most
significant digit is guaranteed to be non-zero.
If a is 0 then the output is an empty list.
"""
a_digits = [0] * ((a.bit_length() + n - 1) // n)
def inner(x, L, R):
if L + 1 == R:
a_digits[L] = x
return
mid = (L + R) >> 1
shift = (mid - L) * n
upper = x >> shift
lower = x ^ (upper << shift)
inner(lower, L, mid)
inner(upper, mid, R)
if a:
inner(a, 0, len(a_digits))
return a_digits
def _digits2int(digits, n):
"""Combine base-2**n digits into an int. This function is the
inverse of `_int2digits`. For more details, see _int2digits.
"""
def inner(L, R):
if L + 1 == R:
return digits[L]
mid = (L + R) >> 1
shift = (mid - L) * n
return (inner(mid, R) << shift) + inner(L, mid)
return inner(0, len(digits)) if digits else 0
def _divmod_pos(a, b):
"""Divide a non-negative integer a by a positive integer b, giving
quotient and remainder."""
# Use grade-school algorithm in base 2**n, n = nbits(b)
n = b.bit_length()
a_digits = _int2digits(a, n)
r = 0
q_digits = []
for a_digit in reversed(a_digits):
q_digit, r = _div2n1n((r << n) + a_digit, b, n)
q_digits.append(q_digit)
q_digits.reverse()
q = _digits2int(q_digits, n)
return q, r
def int_divmod(a, b):
"""Asymptotically fast replacement for divmod, for 'int'.
Its time complexity is O(n**1.58), where n = #bits(a) + #bits(b).
"""
if _DEBUG:
print('int_divmod', a.bit_length(), b.bit_length(), file=sys.stderr)
if b == 0:
raise ZeroDivisionError
elif b < 0:
q, r = int_divmod(-a, -b)
return q, -r
elif a < 0:
q, r = int_divmod(~a, b)
return ~q, b + ~r
else:
return _divmod_pos(a, b)

View File

@ -795,5 +795,52 @@ class IntSubclassStrDigitLimitsTests(IntStrDigitLimitsTests):
int_class = IntSubclass
class PyLongModuleTests(unittest.TestCase):
# Tests of the functions in _pylong.py. Those get used when the
# number of digits in the input values are large enough.
def setUp(self):
super().setUp()
self._previous_limit = sys.get_int_max_str_digits()
sys.set_int_max_str_digits(0)
def tearDown(self):
sys.set_int_max_str_digits(self._previous_limit)
super().tearDown()
def test_pylong_int_to_decimal(self):
n = (1 << 100_000) - 1
suffix = '9883109375'
s = str(n)
assert s[-10:] == suffix
s = str(-n)
assert s[-10:] == suffix
s = '%d' % n
assert s[-10:] == suffix
s = b'%d' % n
assert s[-10:] == suffix.encode('ascii')
def test_pylong_int_divmod(self):
n = (1 << 100_000)
a, b = divmod(n*3 + 1, n)
assert a == 3 and b == 1
def test_pylong_str_to_int(self):
v1 = 1 << 100_000
s = str(v1)
v2 = int(s)
assert v1 == v2
v3 = int(' -' + s)
assert -v1 == v3
v4 = int(' +' + s + ' ')
assert v1 == v4
with self.assertRaises(ValueError) as err:
int(s + 'z')
with self.assertRaises(ValueError) as err:
int(s + '_')
with self.assertRaises(ValueError) as err:
int('_' + s)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,3 @@
Add _pylong.py module. It includes asymptotically faster algorithms that
can be used for operations on integers with many digits. It is used by
longobject.c to speed up some operations.

View File

@ -39,6 +39,9 @@ medium_value(PyLongObject *x)
#define _MAX_STR_DIGITS_ERROR_FMT_TO_INT "Exceeds the limit (%d) for integer string conversion: value has %zd digits; use sys.set_int_max_str_digits() to increase the limit"
#define _MAX_STR_DIGITS_ERROR_FMT_TO_STR "Exceeds the limit (%d) for integer string conversion; use sys.set_int_max_str_digits() to increase the limit"
/* If defined, use algorithms from the _pylong.py module */
#define WITH_PYLONG_MODULE 1
static inline void
_Py_DECREF_INT(PyLongObject *op)
{
@ -1732,6 +1735,69 @@ rem1(PyLongObject *a, digit n)
);
}
#ifdef WITH_PYLONG_MODULE
/* asymptotically faster long_to_decimal_string, using _pylong.py */
static int
pylong_int_to_decimal_string(PyObject *aa,
PyObject **p_output,
_PyUnicodeWriter *writer,
_PyBytesWriter *bytes_writer,
char **bytes_str)
{
PyObject *s = NULL;
PyObject *mod = PyImport_ImportModule("_pylong");
if (mod == NULL) {
return -1;
}
s = PyObject_CallMethod(mod, "int_to_decimal_string", "O", aa);
if (s == NULL) {
goto error;
}
assert(PyUnicode_Check(s));
if (writer) {
Py_ssize_t size = PyUnicode_GET_LENGTH(s);
if (_PyUnicodeWriter_Prepare(writer, size, '9') == -1) {
goto error;
}
if (_PyUnicodeWriter_WriteStr(writer, s) < 0) {
goto error;
}
goto success;
}
else if (bytes_writer) {
Py_ssize_t size = PyUnicode_GET_LENGTH(s);
const void *data = PyUnicode_DATA(s);
int kind = PyUnicode_KIND(s);
*bytes_str = _PyBytesWriter_Prepare(bytes_writer, *bytes_str, size);
if (*bytes_str == NULL) {
goto error;
}
char *p = *bytes_str;
for (Py_ssize_t i=0; i < size; i++) {
Py_UCS4 ch = PyUnicode_READ(kind, data, i);
*p++ = (char) ch;
}
(*bytes_str) = p;
goto success;
}
else {
*p_output = (PyObject *)s;
Py_INCREF(s);
goto success;
}
error:
Py_DECREF(mod);
Py_XDECREF(s);
return -1;
success:
Py_DECREF(mod);
Py_DECREF(s);
return 0;
}
#endif /* WITH_PYLONG_MODULE */
/* Convert an integer to a base 10 string. Returns a new non-shared
string. (Return value is non-shared so that callers can modify the
returned value if necessary.) */
@ -1776,6 +1842,17 @@ long_to_decimal_string_internal(PyObject *aa,
}
}
#if WITH_PYLONG_MODULE
if (size_a > 1000) {
/* Switch to _pylong.int_to_decimal_string(). */
return pylong_int_to_decimal_string(aa,
p_output,
writer,
bytes_writer,
bytes_str);
}
#endif
/* quick and dirty upper bound for the number of digits
required to express a in base _PyLong_DECIMAL_BASE:
@ -2272,6 +2349,39 @@ long_from_binary_base(const char *start, const char *end, Py_ssize_t digits, int
return 0;
}
static PyObject *long_neg(PyLongObject *v);
#ifdef WITH_PYLONG_MODULE
/* asymptotically faster str-to-long conversion for base 10, using _pylong.py */
static int
pylong_int_from_string(const char *start, const char *end, PyLongObject **res)
{
PyObject *mod = PyImport_ImportModule("_pylong");
if (mod == NULL) {
goto error;
}
PyObject *s = PyUnicode_FromStringAndSize(start, end-start);
if (s == NULL) {
goto error;
}
PyObject *result = PyObject_CallMethod(mod, "int_from_string", "O", s);
Py_DECREF(s);
Py_DECREF(mod);
if (result == NULL) {
goto error;
}
if (!PyLong_Check(result)) {
PyErr_SetString(PyExc_TypeError, "an integer is required");
goto error;
}
*res = (PyLongObject *)result;
return 0;
error:
*res = NULL;
return 0;
}
#endif /* WITH_PYLONG_MODULE */
/***
long_from_non_binary_base: parameters and return values are the same as
long_from_binary_base.
@ -2586,6 +2696,12 @@ long_from_string_base(const char **str, int base, PyLongObject **res)
return 0;
}
}
#if WITH_PYLONG_MODULE
if (digits > 6000 && base == 10) {
/* Switch to _pylong.int_from_string() */
return pylong_int_from_string(start, end, res);
}
#endif
/* Use the quadratic algorithm for non binary bases. */
return long_from_non_binary_base(start, end, digits, base, res);
}
@ -3913,6 +4029,48 @@ fast_floor_div(PyLongObject *a, PyLongObject *b)
return PyLong_FromLong(div);
}
#ifdef WITH_PYLONG_MODULE
/* asymptotically faster divmod, using _pylong.py */
static int
pylong_int_divmod(PyLongObject *v, PyLongObject *w,
PyLongObject **pdiv, PyLongObject **pmod)
{
PyObject *mod = PyImport_ImportModule("_pylong");
if (mod == NULL) {
return -1;
}
PyObject *result = PyObject_CallMethod(mod, "int_divmod", "OO", v, w);
Py_DECREF(mod);
if (result == NULL) {
return -1;
}
if (!PyTuple_Check(result)) {
Py_DECREF(result);
PyErr_SetString(PyExc_ValueError,
"tuple is required from int_divmod()");
return -1;
}
PyObject *q = PyTuple_GET_ITEM(result, 0);
PyObject *r = PyTuple_GET_ITEM(result, 1);
if (!PyLong_Check(q) || !PyLong_Check(r)) {
Py_DECREF(result);
PyErr_SetString(PyExc_ValueError,
"tuple of int is required from int_divmod()");
return -1;
}
if (pdiv != NULL) {
Py_INCREF(q);
*pdiv = (PyLongObject *)q;
}
if (pmod != NULL) {
Py_INCREF(r);
*pmod = (PyLongObject *)r;
}
Py_DECREF(result);
return 0;
}
#endif /* WITH_PYLONG_MODULE */
/* The / and % operators are now defined in terms of divmod().
The expression a mod b has the value a - b*floor(a/b).
The long_divrem function gives the remainder after division of
@ -3964,6 +4122,18 @@ l_divmod(PyLongObject *v, PyLongObject *w,
}
return 0;
}
#if WITH_PYLONG_MODULE
Py_ssize_t size_v = Py_ABS(Py_SIZE(v)); /* digits in numerator */
Py_ssize_t size_w = Py_ABS(Py_SIZE(w)); /* digits in denominator */
if (size_w > 300 && (size_v - size_w) > 150) {
/* Switch to _pylong.int_divmod(). If the quotient is small then
"schoolbook" division is linear-time so don't use in that case.
These limits are empirically determined and should be slightly
conservative so that _pylong is used in cases it is likely
to be faster. See Tools/scripts/divmod_threshold.py. */
return pylong_int_divmod(v, w, pdiv, pmod);
}
#endif
if (long_divrem(v, w, &div, &mod) < 0)
return -1;
if ((Py_SIZE(mod) < 0 && Py_SIZE(w) > 0) ||

View File

@ -58,6 +58,7 @@ static const char* _Py_stdlib_module_names[] = {
"_py_abc",
"_pydecimal",
"_pyio",
"_pylong",
"_queue",
"_random",
"_scproxy",

View File

@ -0,0 +1,56 @@
#!/usr/bin/env python3
#
# Determine threshold for switching from longobject.c divmod to
# _pylong.int_divmod().
from random import randrange
from time import perf_counter as now
from _pylong import int_divmod as divmod_fast
BITS_PER_DIGIT = 30
def rand_digits(n):
top = 1 << (n * BITS_PER_DIGIT)
return randrange(top >> 1, top)
def probe_den(nd):
den = rand_digits(nd)
count = 0
for nn in range(nd, nd + 3000):
num = rand_digits(nn)
t0 = now()
e1, e2 = divmod(num, den)
t1 = now()
f1, f2 = divmod_fast(num, den)
t2 = now()
s1 = t1 - t0
s2 = t2 - t1
assert e1 == f1
assert e2 == f2
if s2 < s1:
count += 1
if count >= 3:
print(
"for",
nd,
"denom digits,",
nn - nd,
"extra num digits is enough",
)
break
else:
count = 0
else:
print("for", nd, "denom digits, no num seems big enough")
def main():
for nd in range(30):
nd = (nd + 1) * 100
probe_den(nd)
if __name__ == '__main__':
main()