gh-90716: Refactor PyLong_FromString to separate concerns (GH-96808)

This is a preliminary PR to refactor `PyLong_FromString`, which is currently quite messy: it has spaghetti-like code that mixes up different concerns and duplicates logic.

In particular:

- `PyLong_FromString` now only handles sign, base and prefix detection and calls a new function `long_from_string_base` to parse the main body of the string.
- The `long_from_string_base` function handles all string validation and then calls `long_from_binary_base` or a new function `long_from_non_binary_base` to construct the actual `PyLong`.
- The existing `long_from_binary_base` function is simplified by factoring duplicated logic to `long_from_string_base`.
- The new function `long_from_non_binary_base` factors out much of the code from `PyLong_FromString`, including in particular the quadratic algorithm referred to in gh-95778, so that this can be seen separately from unrelated concerns such as string validation.
This commit is contained in:
Oscar Benjamin 2022-09-25 10:09:50 +01:00 committed by GitHub
parent ea4be278fa
commit 817fa28f81
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 302 additions and 261 deletions

View File

@ -0,0 +1,2 @@
The ``PyLong_FromString`` function was refactored to make it more maintainable
and extensible.

View File

@ -2193,23 +2193,23 @@ unsigned char _PyLong_DigitValue[256] = {
37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37,
}; };
/* *str points to the first digit in a string of base `base` digits. base /* `start` and `end` point to the start and end of a string of base `base`
* is a power of 2 (2, 4, 8, 16, or 32). *str is set to point to the first * digits. base is a power of 2 (2, 4, 8, 16, or 32). An unnormalized int is
* non-digit (which may be *str!). A normalized int is returned. * returned in *res. The string should be already validated by the caller and
* The point to this routine is that it takes time linear in the number of * consists only of valid digit characters and underscores. `digits` gives the
* string characters. * number of digit characters.
*
* The point to this routine is that it takes time linear in the
* number of string characters.
* *
* Return values: * Return values:
* -1 on syntax error (exception needs to be set, *res is untouched) * -1 on syntax error (exception needs to be set, *res is untouched)
* 0 else (exception may be set, in that case *res is set to NULL) * 0 else (exception may be set, in that case *res is set to NULL)
*/ */
static int static int
long_from_binary_base(const char **str, int base, PyLongObject **res) long_from_binary_base(const char *start, const char *end, Py_ssize_t digits, int base, PyLongObject **res)
{ {
const char *p = *str; const char *p;
const char *start = p;
char prev = 0;
Py_ssize_t digits = 0;
int bits_per_char; int bits_per_char;
Py_ssize_t n; Py_ssize_t n;
PyLongObject *z; PyLongObject *z;
@ -2222,26 +2222,7 @@ long_from_binary_base(const char **str, int base, PyLongObject **res)
for (bits_per_char = -1; n; ++bits_per_char) { for (bits_per_char = -1; n; ++bits_per_char) {
n >>= 1; n >>= 1;
} }
/* count digits and set p to end-of-string */
while (_PyLong_DigitValue[Py_CHARMASK(*p)] < base || *p == '_') {
if (*p == '_') {
if (prev == '_') {
*str = p - 1;
return -1;
}
} else {
++digits;
}
prev = *p;
++p;
}
if (prev == '_') {
/* Trailing underscore not allowed. */
*str = p - 1;
return -1;
}
*str = p;
/* n <- the number of Python digits needed, /* n <- the number of Python digits needed,
= ceiling((digits * bits_per_char) / PyLong_SHIFT). */ = ceiling((digits * bits_per_char) / PyLong_SHIFT). */
if (digits > (PY_SSIZE_T_MAX - (PyLong_SHIFT - 1)) / bits_per_char) { if (digits > (PY_SSIZE_T_MAX - (PyLong_SHIFT - 1)) / bits_per_char) {
@ -2262,6 +2243,7 @@ long_from_binary_base(const char **str, int base, PyLongObject **res)
accum = 0; accum = 0;
bits_in_accum = 0; bits_in_accum = 0;
pdigit = z->ob_digit; pdigit = z->ob_digit;
p = end;
while (--p >= start) { while (--p >= start) {
int k; int k;
if (*p == '_') { if (*p == '_') {
@ -2286,88 +2268,14 @@ long_from_binary_base(const char **str, int base, PyLongObject **res)
} }
while (pdigit - z->ob_digit < n) while (pdigit - z->ob_digit < n)
*pdigit++ = 0; *pdigit++ = 0;
*res = long_normalize(z); *res = z;
return 0; return 0;
} }
/* Parses an int from a bytestring. Leading and trailing whitespace will be
* ignored.
*
* If successful, a PyLong object will be returned and 'pend' will be pointing
* to the first unused byte unless it's NULL.
*
* If unsuccessful, NULL will be returned.
*/
PyObject *
PyLong_FromString(const char *str, char **pend, int base)
{
int sign = 1, error_if_nonzero = 0;
const char *start, *orig_str = str;
PyLongObject *z = NULL;
PyObject *strobj;
Py_ssize_t slen;
if ((base != 0 && base < 2) || base > 36) {
PyErr_SetString(PyExc_ValueError,
"int() arg 2 must be >= 2 and <= 36");
return NULL;
}
while (*str != '\0' && Py_ISSPACE(*str)) {
str++;
}
if (*str == '+') {
++str;
}
else if (*str == '-') {
++str;
sign = -1;
}
if (base == 0) {
if (str[0] != '0') {
base = 10;
}
else if (str[1] == 'x' || str[1] == 'X') {
base = 16;
}
else if (str[1] == 'o' || str[1] == 'O') {
base = 8;
}
else if (str[1] == 'b' || str[1] == 'B') {
base = 2;
}
else {
/* "old" (C-style) octal literal, now invalid.
it might still be zero though */
error_if_nonzero = 1;
base = 10;
}
}
if (str[0] == '0' &&
((base == 16 && (str[1] == 'x' || str[1] == 'X')) ||
(base == 8 && (str[1] == 'o' || str[1] == 'O')) ||
(base == 2 && (str[1] == 'b' || str[1] == 'B')))) {
str += 2;
/* One underscore allowed here. */
if (*str == '_') {
++str;
}
}
if (str[0] == '_') {
/* May not start with underscores. */
goto onError;
}
start = str;
if ((base & (base - 1)) == 0) {
/* binary bases are not limited by int_max_str_digits */
int res = long_from_binary_base(&str, base, &z);
if (res < 0) {
/* Syntax error. */
goto onError;
}
}
else {
/*** /***
long_from_non_binary_base: parameters and return values are the same as
long_from_binary_base.
Binary bases can be converted in time linear in the number of digits, because Binary bases can be converted in time linear in the number of digits, because
Python's representation base is binary. Other bases (including decimal!) use Python's representation base is binary. Other bases (including decimal!) use
the simple quadratic-time algorithm below, complicated by some speed tricks. the simple quadratic-time algorithm below, complicated by some speed tricks.
@ -2452,171 +2360,311 @@ that triggers it(!). Instead the code was tested by artificially allocating
just 1 digit at the start, so that the copying code was exercised for every just 1 digit at the start, so that the copying code was exercised for every
digit beyond the first. digit beyond the first.
***/ ***/
twodigits c; /* current input character */ static int
Py_ssize_t size_z; long_from_non_binary_base(const char *start, const char *end, Py_ssize_t digits, int base, PyLongObject **res)
Py_ssize_t digits = 0; {
int i; twodigits c; /* current input character */
int convwidth; Py_ssize_t size_z;
twodigits convmultmax, convmult; int i;
digit *pz, *pzstop; int convwidth;
const char *scan, *lastdigit; twodigits convmultmax, convmult;
char prev = 0; digit *pz, *pzstop;
PyLongObject *z;
const char *p;
static double log_base_BASE[37] = {0.0e0,}; static double log_base_BASE[37] = {0.0e0,};
static int convwidth_base[37] = {0,}; static int convwidth_base[37] = {0,};
static twodigits convmultmax_base[37] = {0,}; static twodigits convmultmax_base[37] = {0,};
if (log_base_BASE[base] == 0.0) { if (log_base_BASE[base] == 0.0) {
twodigits convmax = base; twodigits convmax = base;
int i = 1; int i = 1;
log_base_BASE[base] = (log((double)base) / log_base_BASE[base] = (log((double)base) /
log((double)PyLong_BASE)); log((double)PyLong_BASE));
for (;;) { for (;;) {
twodigits next = convmax * base; twodigits next = convmax * base;
if (next > PyLong_BASE) { if (next > PyLong_BASE) {
break; break;
}
convmax = next;
++i;
} }
convmultmax_base[base] = convmax; convmax = next;
assert(i > 0); ++i;
convwidth_base[base] = i; }
convmultmax_base[base] = convmax;
assert(i > 0);
convwidth_base[base] = i;
}
/* Create an int object that can contain the largest possible
* integer with this base and length. Note that there's no
* need to initialize z->ob_digit -- no slot is read up before
* being stored into.
*/
double fsize_z = (double)digits * log_base_BASE[base] + 1.0;
if (fsize_z > (double)MAX_LONG_DIGITS) {
/* The same exception as in _PyLong_New(). */
PyErr_SetString(PyExc_OverflowError,
"too many digits in integer");
*res = NULL;
return 0;
}
size_z = (Py_ssize_t)fsize_z;
/* Uncomment next line to test exceedingly rare copy code */
/* size_z = 1; */
assert(size_z > 0);
z = _PyLong_New(size_z);
if (z == NULL) {
*res = NULL;
return 0;
}
Py_SET_SIZE(z, 0);
/* `convwidth` consecutive input digits are treated as a single
* digit in base `convmultmax`.
*/
convwidth = convwidth_base[base];
convmultmax = convmultmax_base[base];
/* Work ;-) */
p = start;
while (p < end) {
if (*p == '_') {
p++;
continue;
}
/* grab up to convwidth digits from the input string */
c = (digit)_PyLong_DigitValue[Py_CHARMASK(*p++)];
for (i = 1; i < convwidth && p != end; ++p) {
if (*p == '_') {
continue;
}
i++;
c = (twodigits)(c * base +
(int)_PyLong_DigitValue[Py_CHARMASK(*p)]);
assert(c < PyLong_BASE);
} }
/* Find length of the string of numeric characters. */ convmult = convmultmax;
scan = str; /* Calculate the shift only if we couldn't get
lastdigit = str; * convwidth digits.
*/
if (i != convwidth) {
convmult = base;
for ( ; i > 1; --i) {
convmult *= base;
}
}
while (_PyLong_DigitValue[Py_CHARMASK(*scan)] < base || *scan == '_') { /* Multiply z by convmult, and add c. */
if (*scan == '_') { pz = z->ob_digit;
if (prev == '_') { pzstop = pz + Py_SIZE(z);
/* Only one underscore allowed. */ for (; pz < pzstop; ++pz) {
str = lastdigit + 1; c += (twodigits)*pz * convmult;
goto onError; *pz = (digit)(c & PyLong_MASK);
} c >>= PyLong_SHIFT;
}
/* carry off the current end? */
if (c) {
assert(c < PyLong_BASE);
if (Py_SIZE(z) < size_z) {
*pz = (digit)c;
Py_SET_SIZE(z, Py_SIZE(z) + 1);
} }
else { else {
++digits; PyLongObject *tmp;
lastdigit = scan; /* Extremely rare. Get more space. */
assert(Py_SIZE(z) == size_z);
tmp = _PyLong_New(size_z + 1);
if (tmp == NULL) {
Py_DECREF(z);
*res = NULL;
return 0;
}
memcpy(tmp->ob_digit,
z->ob_digit,
sizeof(digit) * size_z);
Py_DECREF(z);
z = tmp;
z->ob_digit[size_z] = (digit)c;
++size_z;
} }
prev = *scan;
++scan;
}
if (prev == '_') {
/* Trailing underscore not allowed. */
/* Set error pointer to first underscore. */
str = lastdigit + 1;
goto onError;
} }
}
*res = z;
return 0;
}
/* Limit the size to avoid excessive computation attacks. */ /* *str points to the first digit in a string of base `base` digits. base is an
* integer from 2 to 36 inclusive. Here we don't need to worry about prefixes
* like 0x or leading +- signs. The string should be null terminated consisting
* of ASCII digits and separating underscores possibly with trailing whitespace
* but we have to validate all of those points here.
*
* If base is a power of 2 then the complexity is linear in the number of
* characters in the string. Otherwise a quadratic algorithm is used for
* non-binary bases.
*
* Return values:
*
* - Returns -1 on syntax error (exception needs to be set, *res is untouched)
* - Returns 0 and sets *res to NULL for MemoryError/OverflowError.
* - Returns 0 and sets *res to an unsigned, unnormalized PyLong (success!).
*
* Afterwards *str is set to point to the first non-digit (which may be *str!).
*/
static int
long_from_string_base(const char **str, int base, PyLongObject **res)
{
const char *start, *end, *p;
char prev = 0;
Py_ssize_t digits = 0;
int is_binary_base = (base & (base - 1)) == 0;
/* Here we do four things:
*
* - Find the `end` of the string.
* - Validate the string.
* - Count the number of `digits` (rather than underscores)
* - Point *str to the end-of-string or first invalid character.
*/
start = p = *str;
/* Leading underscore not allowed. */
if (*start == '_') {
return -1;
}
/* Verify all characters are digits and underscores. */
while (_PyLong_DigitValue[Py_CHARMASK(*p)] < base || *p == '_') {
if (*p == '_') {
/* Double underscore not allowed. */
if (prev == '_') {
*str = p - 1;
return -1;
}
} else {
++digits;
}
prev = *p;
++p;
}
/* Trailing underscore not allowed. */
if (prev == '_') {
*str = p - 1;
return -1;
}
*str = end = p;
/* Reject empty strings */
if (start == end) {
return -1;
}
/* Allow only trailing whitespace after `end` */
while (*p && Py_ISSPACE(*p)) {
p++;
}
*str = p;
if (*p != '\0') {
return -1;
}
/*
* Pass a validated string consisting of only valid digits and underscores
* to long_from_xxx_base.
*/
if (is_binary_base) {
/* Use the linear algorithm for binary bases. */
return long_from_binary_base(start, end, digits, base, res);
}
else {
/* Limit the size to avoid excessive computation attacks exploiting the
* quadratic algorithm. */
if (digits > _PY_LONG_MAX_STR_DIGITS_THRESHOLD) { if (digits > _PY_LONG_MAX_STR_DIGITS_THRESHOLD) {
PyInterpreterState *interp = _PyInterpreterState_GET(); PyInterpreterState *interp = _PyInterpreterState_GET();
int max_str_digits = interp->int_max_str_digits; int max_str_digits = interp->int_max_str_digits;
if ((max_str_digits > 0) && (digits > max_str_digits)) { if ((max_str_digits > 0) && (digits > max_str_digits)) {
PyErr_Format(PyExc_ValueError, _MAX_STR_DIGITS_ERROR_FMT_TO_INT, PyErr_Format(PyExc_ValueError, _MAX_STR_DIGITS_ERROR_FMT_TO_INT,
max_str_digits, digits); max_str_digits, digits);
return NULL; *res = NULL;
} return 0;
}
/* Create an int object that can contain the largest possible
* integer with this base and length. Note that there's no
* need to initialize z->ob_digit -- no slot is read up before
* being stored into.
*/
double fsize_z = (double)digits * log_base_BASE[base] + 1.0;
if (fsize_z > (double)MAX_LONG_DIGITS) {
/* The same exception as in _PyLong_New(). */
PyErr_SetString(PyExc_OverflowError,
"too many digits in integer");
return NULL;
}
size_z = (Py_ssize_t)fsize_z;
/* Uncomment next line to test exceedingly rare copy code */
/* size_z = 1; */
assert(size_z > 0);
z = _PyLong_New(size_z);
if (z == NULL) {
return NULL;
}
Py_SET_SIZE(z, 0);
/* `convwidth` consecutive input digits are treated as a single
* digit in base `convmultmax`.
*/
convwidth = convwidth_base[base];
convmultmax = convmultmax_base[base];
/* Work ;-) */
while (str < scan) {
if (*str == '_') {
str++;
continue;
}
/* grab up to convwidth digits from the input string */
c = (digit)_PyLong_DigitValue[Py_CHARMASK(*str++)];
for (i = 1; i < convwidth && str != scan; ++str) {
if (*str == '_') {
continue;
}
i++;
c = (twodigits)(c * base +
(int)_PyLong_DigitValue[Py_CHARMASK(*str)]);
assert(c < PyLong_BASE);
}
convmult = convmultmax;
/* Calculate the shift only if we couldn't get
* convwidth digits.
*/
if (i != convwidth) {
convmult = base;
for ( ; i > 1; --i) {
convmult *= base;
}
}
/* Multiply z by convmult, and add c. */
pz = z->ob_digit;
pzstop = pz + Py_SIZE(z);
for (; pz < pzstop; ++pz) {
c += (twodigits)*pz * convmult;
*pz = (digit)(c & PyLong_MASK);
c >>= PyLong_SHIFT;
}
/* carry off the current end? */
if (c) {
assert(c < PyLong_BASE);
if (Py_SIZE(z) < size_z) {
*pz = (digit)c;
Py_SET_SIZE(z, Py_SIZE(z) + 1);
}
else {
PyLongObject *tmp;
/* Extremely rare. Get more space. */
assert(Py_SIZE(z) == size_z);
tmp = _PyLong_New(size_z + 1);
if (tmp == NULL) {
Py_DECREF(z);
return NULL;
}
memcpy(tmp->ob_digit,
z->ob_digit,
sizeof(digit) * size_z);
Py_DECREF(z);
z = tmp;
z->ob_digit[size_z] = (digit)c;
++size_z;
}
} }
} }
/* Use the quadratic algorithm for non binary bases. */
return long_from_non_binary_base(start, end, digits, base, res);
} }
if (z == NULL) { }
/* Parses an int from a bytestring. Leading and trailing whitespace will be
* ignored.
*
* If successful, a PyLong object will be returned and 'pend' will be pointing
* to the first unused byte unless it's NULL.
*
* If unsuccessful, NULL will be returned.
*/
PyObject *
PyLong_FromString(const char *str, char **pend, int base)
{
int sign = 1, error_if_nonzero = 0;
const char *orig_str = str;
PyLongObject *z = NULL;
PyObject *strobj;
Py_ssize_t slen;
if ((base != 0 && base < 2) || base > 36) {
PyErr_SetString(PyExc_ValueError,
"int() arg 2 must be >= 2 and <= 36");
return NULL; return NULL;
} }
while (*str != '\0' && Py_ISSPACE(*str)) {
++str;
}
if (*str == '+') {
++str;
}
else if (*str == '-') {
++str;
sign = -1;
}
if (base == 0) {
if (str[0] != '0') {
base = 10;
}
else if (str[1] == 'x' || str[1] == 'X') {
base = 16;
}
else if (str[1] == 'o' || str[1] == 'O') {
base = 8;
}
else if (str[1] == 'b' || str[1] == 'B') {
base = 2;
}
else {
/* "old" (C-style) octal literal, now invalid.
it might still be zero though */
error_if_nonzero = 1;
base = 10;
}
}
if (str[0] == '0' &&
((base == 16 && (str[1] == 'x' || str[1] == 'X')) ||
(base == 8 && (str[1] == 'o' || str[1] == 'O')) ||
(base == 2 && (str[1] == 'b' || str[1] == 'B')))) {
str += 2;
/* One underscore allowed here. */
if (*str == '_') {
++str;
}
}
/* long_from_string_base is the main workhorse here. */
int ret = long_from_string_base(&str, base, &z);
if (ret == -1) {
/* Syntax error. */
goto onError;
}
if (z == NULL) {
/* Error. exception already set. */
return NULL;
}
if (error_if_nonzero) { if (error_if_nonzero) {
/* reset the base to 0, else the exception message /* reset the base to 0, else the exception message
doesn't make too much sense */ doesn't make too much sense */
@ -2627,23 +2675,14 @@ digit beyond the first.
/* there might still be other problems, therefore base /* there might still be other problems, therefore base
remains zero here for the same reason */ remains zero here for the same reason */
} }
if (str == start) {
goto onError; /* Set sign and normalize */
}
if (sign < 0) { if (sign < 0) {
Py_SET_SIZE(z, -(Py_SIZE(z))); Py_SET_SIZE(z, -(Py_SIZE(z)));
} }
while (*str && Py_ISSPACE(*str)) {
str++;
}
if (*str != '\0') {
goto onError;
}
long_normalize(z); long_normalize(z);
z = maybe_small_long(z); z = maybe_small_long(z);
if (z == NULL) {
return NULL;
}
if (pend != NULL) { if (pend != NULL) {
*pend = (char *)str; *pend = (char *)str;
} }