gh-90716: Refactor PyLong_FromString to separate concerns (GH-96808)

This is a preliminary PR to refactor `PyLong_FromString`, which is currently quite messy and has spaghetti-like code that mixes up different concerns and duplicates logic.

In particular:

- `PyLong_FromString` now only handles sign, base and prefix detection and calls a new function `long_from_string_base` to parse the main body of the string.
- The `long_from_string_base` function handles all string validation and then calls `long_from_binary_base` or a new function `long_from_non_binary_base` to construct the actual `PyLong`.
- The existing `long_from_binary_base` function is simplified by factoring duplicated logic to `long_from_string_base`.
- The new function `long_from_non_binary_base` factors out much of the code from `PyLong_FromString`, including in particular the quadratic algorithm referred to in gh-95778, so that this can be seen separately from unrelated concerns such as string validation.
This commit is contained in:
Oscar Benjamin 2022-09-25 10:09:50 +01:00 committed by GitHub
parent ea4be278fa
commit 817fa28f81
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 302 additions and 261 deletions

View File

@ -0,0 +1,2 @@
The ``PyLong_FromString`` function was refactored to make it more maintainable
and extensible.

View File

@ -2193,23 +2193,23 @@ unsigned char _PyLong_DigitValue[256] = {
37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37,
};
/* *str points to the first digit in a string of base `base` digits. base
* is a power of 2 (2, 4, 8, 16, or 32). *str is set to point to the first
* non-digit (which may be *str!). A normalized int is returned.
* The point to this routine is that it takes time linear in the number of
* string characters.
/* `start` and `end` point to the start and end of a string of base `base`
* digits. base is a power of 2 (2, 4, 8, 16, or 32). An unnormalized int is
* returned in *res. The string should be already validated by the caller and
* consists only of valid digit characters and underscores. `digits` gives the
* number of digit characters.
*
* The point to this routine is that it takes time linear in the
* number of string characters.
*
* Return values:
* -1 on syntax error (exception needs to be set, *res is untouched)
* 0 else (exception may be set, in that case *res is set to NULL)
*/
static int
long_from_binary_base(const char **str, int base, PyLongObject **res)
long_from_binary_base(const char *start, const char *end, Py_ssize_t digits, int base, PyLongObject **res)
{
const char *p = *str;
const char *start = p;
char prev = 0;
Py_ssize_t digits = 0;
const char *p;
int bits_per_char;
Py_ssize_t n;
PyLongObject *z;
@ -2222,26 +2222,7 @@ long_from_binary_base(const char **str, int base, PyLongObject **res)
for (bits_per_char = -1; n; ++bits_per_char) {
n >>= 1;
}
/* count digits and set p to end-of-string */
while (_PyLong_DigitValue[Py_CHARMASK(*p)] < base || *p == '_') {
if (*p == '_') {
if (prev == '_') {
*str = p - 1;
return -1;
}
} else {
++digits;
}
prev = *p;
++p;
}
if (prev == '_') {
/* Trailing underscore not allowed. */
*str = p - 1;
return -1;
}
*str = p;
/* n <- the number of Python digits needed,
= ceiling((digits * bits_per_char) / PyLong_SHIFT). */
if (digits > (PY_SSIZE_T_MAX - (PyLong_SHIFT - 1)) / bits_per_char) {
@ -2262,6 +2243,7 @@ long_from_binary_base(const char **str, int base, PyLongObject **res)
accum = 0;
bits_in_accum = 0;
pdigit = z->ob_digit;
p = end;
while (--p >= start) {
int k;
if (*p == '_') {
@ -2286,88 +2268,14 @@ long_from_binary_base(const char **str, int base, PyLongObject **res)
}
while (pdigit - z->ob_digit < n)
*pdigit++ = 0;
*res = long_normalize(z);
*res = z;
return 0;
}
/* Parses an int from a bytestring. Leading and trailing whitespace will be
* ignored.
*
* If successful, a PyLong object will be returned and 'pend' will be pointing
* to the first unused byte unless it's NULL.
*
* If unsuccessful, NULL will be returned.
*/
PyObject *
PyLong_FromString(const char *str, char **pend, int base)
{
int sign = 1, error_if_nonzero = 0;
const char *start, *orig_str = str;
PyLongObject *z = NULL;
PyObject *strobj;
Py_ssize_t slen;
if ((base != 0 && base < 2) || base > 36) {
PyErr_SetString(PyExc_ValueError,
"int() arg 2 must be >= 2 and <= 36");
return NULL;
}
while (*str != '\0' && Py_ISSPACE(*str)) {
str++;
}
if (*str == '+') {
++str;
}
else if (*str == '-') {
++str;
sign = -1;
}
if (base == 0) {
if (str[0] != '0') {
base = 10;
}
else if (str[1] == 'x' || str[1] == 'X') {
base = 16;
}
else if (str[1] == 'o' || str[1] == 'O') {
base = 8;
}
else if (str[1] == 'b' || str[1] == 'B') {
base = 2;
}
else {
/* "old" (C-style) octal literal, now invalid.
it might still be zero though */
error_if_nonzero = 1;
base = 10;
}
}
if (str[0] == '0' &&
((base == 16 && (str[1] == 'x' || str[1] == 'X')) ||
(base == 8 && (str[1] == 'o' || str[1] == 'O')) ||
(base == 2 && (str[1] == 'b' || str[1] == 'B')))) {
str += 2;
/* One underscore allowed here. */
if (*str == '_') {
++str;
}
}
if (str[0] == '_') {
/* May not start with underscores. */
goto onError;
}
start = str;
if ((base & (base - 1)) == 0) {
/* binary bases are not limited by int_max_str_digits */
int res = long_from_binary_base(&str, base, &z);
if (res < 0) {
/* Syntax error. */
goto onError;
}
}
else {
/***
long_from_non_binary_base: parameters and return values are the same as
long_from_binary_base.
Binary bases can be converted in time linear in the number of digits, because
Python's representation base is binary. Other bases (including decimal!) use
the simple quadratic-time algorithm below, complicated by some speed tricks.
@ -2452,15 +2360,17 @@ that triggers it(!). Instead the code was tested by artificially allocating
just 1 digit at the start, so that the copying code was exercised for every
digit beyond the first.
***/
static int
long_from_non_binary_base(const char *start, const char *end, Py_ssize_t digits, int base, PyLongObject **res)
{
twodigits c; /* current input character */
Py_ssize_t size_z;
Py_ssize_t digits = 0;
int i;
int convwidth;
twodigits convmultmax, convmult;
digit *pz, *pzstop;
const char *scan, *lastdigit;
char prev = 0;
PyLongObject *z;
const char *p;
static double log_base_BASE[37] = {0.0e0,};
static int convwidth_base[37] = {0,};
@ -2485,43 +2395,6 @@ digit beyond the first.
convwidth_base[base] = i;
}
/* Find length of the string of numeric characters. */
scan = str;
lastdigit = str;
while (_PyLong_DigitValue[Py_CHARMASK(*scan)] < base || *scan == '_') {
if (*scan == '_') {
if (prev == '_') {
/* Only one underscore allowed. */
str = lastdigit + 1;
goto onError;
}
}
else {
++digits;
lastdigit = scan;
}
prev = *scan;
++scan;
}
if (prev == '_') {
/* Trailing underscore not allowed. */
/* Set error pointer to first underscore. */
str = lastdigit + 1;
goto onError;
}
/* Limit the size to avoid excessive computation attacks. */
if (digits > _PY_LONG_MAX_STR_DIGITS_THRESHOLD) {
PyInterpreterState *interp = _PyInterpreterState_GET();
int max_str_digits = interp->int_max_str_digits;
if ((max_str_digits > 0) && (digits > max_str_digits)) {
PyErr_Format(PyExc_ValueError, _MAX_STR_DIGITS_ERROR_FMT_TO_INT,
max_str_digits, digits);
return NULL;
}
}
/* Create an int object that can contain the largest possible
* integer with this base and length. Note that there's no
* need to initialize z->ob_digit -- no slot is read up before
@ -2532,7 +2405,8 @@ digit beyond the first.
/* The same exception as in _PyLong_New(). */
PyErr_SetString(PyExc_OverflowError,
"too many digits in integer");
return NULL;
*res = NULL;
return 0;
}
size_z = (Py_ssize_t)fsize_z;
/* Uncomment next line to test exceedingly rare copy code */
@ -2540,7 +2414,8 @@ digit beyond the first.
assert(size_z > 0);
z = _PyLong_New(size_z);
if (z == NULL) {
return NULL;
*res = NULL;
return 0;
}
Py_SET_SIZE(z, 0);
@ -2551,20 +2426,21 @@ digit beyond the first.
convmultmax = convmultmax_base[base];
/* Work ;-) */
while (str < scan) {
if (*str == '_') {
str++;
p = start;
while (p < end) {
if (*p == '_') {
p++;
continue;
}
/* grab up to convwidth digits from the input string */
c = (digit)_PyLong_DigitValue[Py_CHARMASK(*str++)];
for (i = 1; i < convwidth && str != scan; ++str) {
if (*str == '_') {
c = (digit)_PyLong_DigitValue[Py_CHARMASK(*p++)];
for (i = 1; i < convwidth && p != end; ++p) {
if (*p == '_') {
continue;
}
i++;
c = (twodigits)(c * base +
(int)_PyLong_DigitValue[Py_CHARMASK(*str)]);
(int)_PyLong_DigitValue[Py_CHARMASK(*p)]);
assert(c < PyLong_BASE);
}
@ -2601,7 +2477,8 @@ digit beyond the first.
tmp = _PyLong_New(size_z + 1);
if (tmp == NULL) {
Py_DECREF(z);
return NULL;
*res = NULL;
return 0;
}
memcpy(tmp->ob_digit,
z->ob_digit,
@ -2613,10 +2490,181 @@ digit beyond the first.
}
}
}
*res = z;
return 0;
}
/* Validate the digit body of an integer literal and convert it to a PyLong.
 *
 * On entry *str points to the first digit in a string of base `base` digits.
 * base is an integer from 2 to 36 inclusive (not checked here). The caller
 * has already consumed any leading whitespace, +- sign and 0x/0o/0b prefix.
 * The string must be null terminated. It is expected to consist of ASCII
 * digits and single separating underscores, possibly followed by trailing
 * whitespace, and all of those points are validated here.
 *
 * If base is a power of 2 then the complexity is linear in the number of
 * characters in the string. Otherwise a quadratic algorithm is used, and the
 * interpreter's int_max_str_digits limit is enforced first.
 *
 * Return values:
 *
 * - Returns -1 on syntax error. No exception is set by this function (the
 *   caller needs to set it) and *res is untouched.
 * - Returns 0 and sets *res to NULL when an exception has been set: the
 *   ValueError for exceeding the digit limit raised here, or an error
 *   reported by long_from_binary_base / long_from_non_binary_base
 *   (MemoryError/OverflowError).
 * - Returns 0 and sets *res to an unsigned, unnormalized PyLong (success!).
 *
 * Position of *str afterwards:
 *
 * - leading underscore: unchanged;
 * - doubled underscore: the first of the two underscores;
 * - trailing underscore: that underscore;
 * - no digits at all: unchanged (it already points at the offending byte);
 * - otherwise: the first byte after the digits and any trailing whitespace,
 *   which is the terminating '\0' exactly when the string was accepted.
 */
static int
long_from_string_base(const char **str, int base, PyLongObject **res)
{
    const char *start, *end, *p;
    char prev = 0;
    Py_ssize_t digits = 0;
    /* A power of two has a single bit set, so clearing the lowest set bit
     * leaves zero. */
    int is_binary_base = (base & (base - 1)) == 0;

    /* Here we do four things:
     *
     * - Find the `end` of the string.
     * - Validate the string.
     * - Count the number of `digits` (rather than underscores)
     * - Point *str to the end-of-string or first invalid character.
     */
    start = p = *str;

    /* Leading underscore not allowed. */
    if (*start == '_') {
        return -1;
    }

    /* Verify all characters are digits and underscores. The terminating
     * '\0' is neither, so the scan always stops inside the string. */
    while (_PyLong_DigitValue[Py_CHARMASK(*p)] < base || *p == '_') {
        if (*p == '_') {
            /* Double underscore not allowed. */
            if (prev == '_') {
                *str = p - 1;
                return -1;
            }
        } else {
            ++digits;
        }
        prev = *p;
        ++p;
    }

    /* Trailing underscore not allowed. */
    if (prev == '_') {
        *str = p - 1;
        return -1;
    }
    *str = end = p;

    /* Reject empty strings */
    if (start == end) {
        return -1;
    }

    /* Allow only trailing whitespace after `end` */
    while (*p && Py_ISSPACE(*p)) {
        p++;
    }
    *str = p;
    if (*p != '\0') {
        return -1;
    }

    /*
     * Pass a validated string consisting of only valid digits and underscores
     * to long_from_xxx_base.
     */
    if (is_binary_base) {
        /* Use the linear algorithm for binary bases. These are not limited
         * by int_max_str_digits. */
        return long_from_binary_base(start, end, digits, base, res);
    }
    else {
        /* Limit the size to avoid excessive computation attacks exploiting the
         * quadratic algorithm. The interpreter state is only consulted once
         * the cheap compile-time threshold has been exceeded. */
        if (digits > _PY_LONG_MAX_STR_DIGITS_THRESHOLD) {
            PyInterpreterState *interp = _PyInterpreterState_GET();
            int max_str_digits = interp->int_max_str_digits;
            /* A non-positive limit disables the check. */
            if ((max_str_digits > 0) && (digits > max_str_digits)) {
                PyErr_Format(PyExc_ValueError, _MAX_STR_DIGITS_ERROR_FMT_TO_INT,
                             max_str_digits, digits);
                *res = NULL;
                return 0;
            }
        }
        /* Use the quadratic algorithm for non binary bases. */
        return long_from_non_binary_base(start, end, digits, base, res);
    }
}
/* Parses an int from a bytestring. Leading and trailing whitespace will be
* ignored.
*
* If successful, a PyLong object will be returned and 'pend' will be pointing
* to the first unused byte unless it's NULL.
*
* If unsuccessful, NULL will be returned.
*/
PyObject *
PyLong_FromString(const char *str, char **pend, int base)
{
int sign = 1, error_if_nonzero = 0;
const char *orig_str = str;
PyLongObject *z = NULL;
PyObject *strobj;
Py_ssize_t slen;
if ((base != 0 && base < 2) || base > 36) {
PyErr_SetString(PyExc_ValueError,
"int() arg 2 must be >= 2 and <= 36");
return NULL;
}
while (*str != '\0' && Py_ISSPACE(*str)) {
++str;
}
if (*str == '+') {
++str;
}
else if (*str == '-') {
++str;
sign = -1;
}
if (base == 0) {
if (str[0] != '0') {
base = 10;
}
else if (str[1] == 'x' || str[1] == 'X') {
base = 16;
}
else if (str[1] == 'o' || str[1] == 'O') {
base = 8;
}
else if (str[1] == 'b' || str[1] == 'B') {
base = 2;
}
else {
/* "old" (C-style) octal literal, now invalid.
it might still be zero though */
error_if_nonzero = 1;
base = 10;
}
}
if (str[0] == '0' &&
((base == 16 && (str[1] == 'x' || str[1] == 'X')) ||
(base == 8 && (str[1] == 'o' || str[1] == 'O')) ||
(base == 2 && (str[1] == 'b' || str[1] == 'B')))) {
str += 2;
/* One underscore allowed here. */
if (*str == '_') {
++str;
}
}
/* long_from_string_base is the main workhorse here. */
int ret = long_from_string_base(&str, base, &z);
if (ret == -1) {
/* Syntax error. */
goto onError;
}
if (z == NULL) {
/* Error. exception already set. */
return NULL;
}
if (error_if_nonzero) {
/* reset the base to 0, else the exception message
doesn't make too much sense */
@ -2627,23 +2675,14 @@ digit beyond the first.
/* there might still be other problems, therefore base
remains zero here for the same reason */
}
if (str == start) {
goto onError;
}
/* Set sign and normalize */
if (sign < 0) {
Py_SET_SIZE(z, -(Py_SIZE(z)));
}
while (*str && Py_ISSPACE(*str)) {
str++;
}
if (*str != '\0') {
goto onError;
}
long_normalize(z);
z = maybe_small_long(z);
if (z == NULL) {
return NULL;
}
if (pend != NULL) {
*pend = (char *)str;
}