diff --git a/Objects/longobject.c b/Objects/longobject.c index 4a37e5e14e4..fab7deee7df 100644 --- a/Objects/longobject.c +++ b/Objects/longobject.c @@ -775,6 +775,57 @@ convert_binop(PyObject *v, PyObject *w, PyLongObject **a, PyLongObject **b) { return Py_NotImplemented; \ } +/* x[0:m] and y[0:n] are digit vectors, LSD first, m >= n required. x[0:n] + * is modified in place, by adding y to it. Carries are propagated as far as + * x[m-1], and the remaining carry (0 or 1) is returned. + */ +static digit +v_iadd(digit *x, int m, digit *y, int n) +{ + int i; + digit carry = 0; + + assert(m >= n); + for (i = 0; i < n; ++i) { + carry += x[i] + y[i]; + x[i] = carry & MASK; + carry >>= SHIFT; + assert((carry & 1) == carry); + } + for (; carry && i < m; ++i) { + carry += x[i]; + x[i] = carry & MASK; + carry >>= SHIFT; + assert((carry & 1) == carry); + } + return carry; +} + +/* x[0:m] and y[0:n] are digit vectors, LSD first, m >= n required. x[0:n] + * is modified in place, by subtracting y from it. Borrows are propagated as + * far as x[m-1], and the remaining borrow (0 or 1) is returned. + */ +static digit +v_isub(digit *x, int m, digit *y, int n) +{ + int i; + digit borrow = 0; + + assert(m >= n); + for (i = 0; i < n; ++i) { + borrow = x[i] - y[i] - borrow; + x[i] = borrow & MASK; + borrow >>= SHIFT; + borrow &= 1; /* keep only 1 sign bit */ + } + for (; borrow && i < m; ++i) { + borrow = x[i] - borrow; + x[i] = borrow & MASK; + borrow >>= SHIFT; + borrow &= 1; + } + return borrow; +} /* Multiply by a single digit, ignoring the sign. */ @@ -1558,7 +1609,9 @@ k_mul(PyLongObject *a, PyLongObject *b) PyLongObject *t1, *t2; int shift; /* the number of digits we split off */ int i; - +#ifdef Py_DEBUG + digit d; +#endif /* (ah*X+al)(bh*X+bl) = ah*bh*X*X + (ah*bl + al*bh)*X + al*bl * Let k = (ah+al)*(bh+bl) = ah*bl + al*bh + ah*bh + al*bl * Then the original product is @@ -1629,39 +1682,32 @@ k_mul(PyLongObject *a, PyLongObject *b) Py_DECREF(t2); if (k == NULL) goto fail; - /* Subtract ahbh and albl from k. Note that this can't become - * negative, since k = ahbh + albl + other stuff. + /* Add k into the result, starting at the shift'th LSD. */ + i = ret->ob_size - shift; /* # digits after shift */ +#ifdef Py_DEBUG + d = +#endif + v_iadd(ret->ob_digit + shift, i, k->ob_digit, k->ob_size); + assert(d == 0); + Py_DECREF(k); + + /* Subtract ahbh and albl from the result. Note that this can't + * become negative, since k = ahbh + albl + other stuff. */ - if ((t1 = x_sub(k, ahbh)) == NULL) goto fail; - Py_DECREF(k); - k = t1; +#ifdef Py_DEBUG + d = +#endif + v_isub(ret->ob_digit + shift, i, ahbh->ob_digit, ahbh->ob_size); + assert(d == 0); Py_DECREF(ahbh); - ahbh = NULL; - if ((t1 = x_sub(k, albl)) == NULL) goto fail; - Py_DECREF(k); - k = t1; +#ifdef Py_DEBUG + d = +#endif + v_isub(ret->ob_digit + shift, i, albl->ob_digit, albl->ob_size); + assert(d == 0); Py_DECREF(albl); - albl = NULL; - /* Add k into the result, at the shift-th least-significant digit. */ - { - int j; /* index into k */ - digit carry = 0; - - for (i = shift, j = 0; j < k->ob_size; ++i, ++j) { - carry += ret->ob_digit[i] + k->ob_digit[j]; - ret->ob_digit[i] = carry & MASK; - carry >>= SHIFT; - } - for (; carry && i < ret->ob_size; ++i) { - carry += ret->ob_digit[i]; - ret->ob_digit[i] = carry & MASK; - carry >>= SHIFT; - } - assert(carry == 0); - } - Py_DECREF(k); return long_normalize(ret); fail: