Cautious introduction of a patch that started from

SF 560379:  Karatsuba multiplication.
Lots of things were changed from that.  This needs a lot more testing,
for correctness and speed, the latter especially when bit lengths are
unbalanced.  For now, the Karatsuba code gets invoked if and only if
envar KARAT exists.
This commit is contained in:
Tim Peters 2002-08-12 02:31:19 +00:00
parent 5f7617b5f6
commit 5af4e6c739
3 changed files with 272 additions and 88 deletions

View File

@ -106,6 +106,7 @@ Greg Couch
Steve Cousins
Alex Coventry
Matthew Dixon Cowles
Christopher A. Craig
Drew Csillag
Tom Culliton
John Cugini

View File

@ -6,6 +6,10 @@ Type/class unification and new-style classes
Core and builtins
- XXX Karatsuba multiplication. This is currently used if and only
if envar KARAT exists. It needs more correctness and speed testing,
the latter especially with unbalanced bit lengths.
- u'%c' will now raise a ValueError in case the argument is an
integer outside the valid range of Unicode code point ordinals.
@ -83,7 +87,7 @@ Core and builtins
as directory names.
- The built-ins slice() and buffer() are now callable types. The
types classobj (formerly class), code, function, instance, and
0 types classobj (formerly class), code, function, instance, and
instancemethod (formerly instance-method), which have no built-in
names but are accessible through the types module, are now also
callable. The type dict-proxy is renamed to dictproxy.

View File

@ -8,8 +8,19 @@
#include <ctype.h>
/* For long multiplication, use the O(N**2) school algorithm unless
* both operands contain more than KARATSUBA_CUTOFF digits (this
* being an internal Python long digit, in base BASE).
*/
#define KARATSUBA_CUTOFF 35
#define ABS(x) ((x) < 0 ? -(x) : (x))
#undef MIN
#undef MAX
#define MAX(x, y) ((x) < (y) ? (y) : (x))
#define MIN(x, y) ((x) > (y) ? (y) : (x))
/* Forward */
static PyLongObject *long_normalize(PyLongObject *);
static PyLongObject *mul1(PyLongObject *, wdigit);
@ -1457,54 +1468,28 @@ long_repeat(PyObject *v, PyLongObject *w)
return (*v->ob_type->tp_as_sequence->sq_repeat)(v, n);
}
static PyObject *
long_mul(PyLongObject *v, PyLongObject *w)
/* Grade school multiplication, ignoring the signs.
* Returns the absolute value of the product, or NULL if error.
*/
static PyLongObject *
x_mul(PyLongObject *a, PyLongObject *b)
{
PyLongObject *a, *b, *z;
int size_a;
int size_b;
PyLongObject *z;
int size_a = ABS(a->ob_size);
int size_b = ABS(b->ob_size);
int i;
if (!convert_binop((PyObject *)v, (PyObject *)w, &a, &b)) {
if (!PyLong_Check(v) &&
v->ob_type->tp_as_sequence &&
v->ob_type->tp_as_sequence->sq_repeat)
return long_repeat((PyObject *)v, w);
if (!PyLong_Check(w) &&
w->ob_type->tp_as_sequence &&
w->ob_type->tp_as_sequence->sq_repeat)
return long_repeat((PyObject *)w, v);
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
size_a = ABS(a->ob_size);
size_b = ABS(b->ob_size);
if (size_a > size_b) {
/* we are faster with the small object on the left */
int hold_sa = size_a;
PyLongObject *hold_a = a;
size_a = size_b;
size_b = hold_sa;
a = b;
b = hold_a;
}
z = _PyLong_New(size_a + size_b);
if (z == NULL) {
Py_DECREF(a);
Py_DECREF(b);
if (z == NULL)
return NULL;
}
for (i = 0; i < z->ob_size; ++i)
z->ob_digit[i] = 0;
memset(z->ob_digit, 0, z->ob_size * sizeof(digit));
for (i = 0; i < size_a; ++i) {
twodigits carry = 0;
twodigits f = a->ob_digit[i];
int j;
SIGCHECK({
Py_DECREF(a);
Py_DECREF(b);
Py_DECREF(z);
return NULL;
})
@ -1520,6 +1505,205 @@ long_mul(PyLongObject *v, PyLongObject *w)
carry >>= SHIFT;
}
}
return z;
}
/* A helper for Karatsuba multiplication (k_mul).
Takes a long "n" and an integer "size" representing the place to
split, and sets low and high such that abs(n) == (high << size) + low,
viewing the shift as being by digits. The sign bit is ignored, and
the return values are >= 0.
Returns 0 on success, -1 on failure.
*/
static int
kmul_split(PyLongObject *n, int size, PyLongObject **high, PyLongObject **low)
{
PyLongObject *hi, *lo;
int size_lo, size_hi;
const int size_n = ABS(n->ob_size);
size_lo = MIN(size_n, size);
size_hi = size_n - size_lo;
if ((hi = _PyLong_New(size_hi)) == NULL)
return -1;
if ((lo = _PyLong_New(size_lo)) == NULL) {
Py_DECREF(hi);
return -1;
}
memcpy(lo->ob_digit, n->ob_digit, size_lo * sizeof(digit));
memcpy(hi->ob_digit, n->ob_digit + size_lo, size_hi * sizeof(digit));
*high = long_normalize(hi);
*low = long_normalize(lo);
return 0;
}
/* Karatsuba multiplication. Ignores the input signs, and returns the
* absolute value of the product (or NULL if error).
* See Knuth Vol. 2 Chapter 4.3.3 (Pp. 294-295).
*/
static PyLongObject *
k_mul(PyLongObject *a, PyLongObject *b)
{
PyLongObject *ah = NULL;
PyLongObject *al = NULL;
PyLongObject *bh = NULL;
PyLongObject *bl = NULL;
PyLongObject *albl = NULL;
PyLongObject *ahbh = NULL;
PyLongObject *k = NULL;
PyLongObject *ret = NULL;
PyLongObject *t1, *t2;
int shift; /* the number of digits we split off */
int i;
/* (ah*X+al)(bh*X+bl) = ah*bh*X*X + (ah*bl + al*bh)*X + al*bl
* Let k = (ah+al)*(bh+bl) = ah*bl + al*bh + ah*bh + al*bl
* Then the original product is
* ah*bh*X*X + (k - ah*bh - ah*bl)*X + al*bl
* By picking X to be a power of 2, "*X" is just shifting, and it's
* been reduced to 3 multiplies on numbers half the size.
*/
/* We want to split based on the larger number; fiddle so that a
* is largest.
*/
if (ABS(a->ob_size) > ABS(b->ob_size)) {
t1 = a;
a = b;
b = t1;
}
/* Use gradeschool math when either number is too small. */
if (ABS(a->ob_size) <= KARATSUBA_CUTOFF)
return x_mul(a, b);
shift = ABS(b->ob_size) >> 1;
if (kmul_split(a, shift, &ah, &al) < 0) goto fail;
if (kmul_split(b, shift, &bh, &bl) < 0) goto fail;
if ((ahbh = k_mul(ah, bh)) == NULL) goto fail;
assert(ahbh->ob_size >= 0);
/* Allocate result space, and copy ahbh into the high digits. */
ret = _PyLong_New(ahbh->ob_size + 2*shift + 1);
if (ret == NULL) goto fail;
#ifdef Py_DEBUG
/* Fill with trash, to catch reference to uninitialized digits. */
memset(ret->ob_digit, 0xDF, ret->ob_size * sizeof(digit));
#endif
memcpy(ret->ob_digit + 2*shift, ahbh->ob_digit,
ahbh->ob_size * sizeof(digit));
/* That didn't copy into the most-significant (overflow) digit. */
ret->ob_digit[ret->ob_size - 1] = 0;
/* Compute al*bl, and copy into the low digits. */
if ((albl = k_mul(al, bl)) == NULL) goto fail;
assert(albl->ob_size >= 0);
assert(albl->ob_size <= 2*shift); /* no overlap with high digits */
memcpy(ret->ob_digit, albl->ob_digit, albl->ob_size * sizeof(digit));
/* Zero out remaining digits. */
i = 2*shift - albl->ob_size; /* number of uninitialized digits */
if (i)
memset(ret->ob_digit + albl->ob_size, 0, i * sizeof(digit));
/* k = (ah+al)(bh+bl) */
if ((t1 = x_add(ah, al)) == NULL) goto fail;
Py_DECREF(ah);
Py_DECREF(al);
ah = al = NULL;
if ((t2 = x_add(bh, bl)) == NULL) {
Py_DECREF(t1);
goto fail;
}
Py_DECREF(bh);
Py_DECREF(bl);
bh = bl = NULL;
k = k_mul(t1, t2);
Py_DECREF(t1);
Py_DECREF(t2);
if (k == NULL) goto fail;
/* Subtract ahbh and albl from k. Note that this can't become
* negative, since k = ahbh + albl + other stuff.
*/
if ((t1 = x_sub(k, ahbh)) == NULL) goto fail;
Py_DECREF(k);
k = t1;
Py_DECREF(ahbh);
ahbh = NULL;
if ((t1 = x_sub(k, albl)) == NULL) goto fail;
Py_DECREF(k);
k = t1;
Py_DECREF(albl);
albl = NULL;
/* Add k into the result, at the shift-th least-significant digit. */
{
int j; /* index into k */
digit carry = 0;
for (i = shift, j = 0; j < k->ob_size; ++i, ++j) {
carry += ret->ob_digit[i] + k->ob_digit[j];
ret->ob_digit[i] = carry & MASK;
carry >>= SHIFT;
}
for (; carry && i < ret->ob_size; ++i) {
carry += ret->ob_digit[i];
ret->ob_digit[i] = carry & MASK;
carry >>= SHIFT;
}
assert(carry == 0);
}
Py_DECREF(k);
return long_normalize(ret);
fail:
Py_XDECREF(ret);
Py_XDECREF(ah);
Py_XDECREF(al);
Py_XDECREF(bh);
Py_XDECREF(bl);
Py_XDECREF(ahbh);
Py_XDECREF(albl);
Py_XDECREF(k);
return NULL;
}
static PyObject *
long_mul(PyLongObject *v, PyLongObject *w)
{
PyLongObject *a, *b, *z;
if (!convert_binop((PyObject *)v, (PyObject *)w, &a, &b)) {
if (!PyLong_Check(v) &&
v->ob_type->tp_as_sequence &&
v->ob_type->tp_as_sequence->sq_repeat)
return long_repeat((PyObject *)v, w);
if (!PyLong_Check(w) &&
w->ob_type->tp_as_sequence &&
w->ob_type->tp_as_sequence->sq_repeat)
return long_repeat((PyObject *)w, v);
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
if (Py_GETENV("KARAT") != NULL)
z = k_mul(a, b);
else
z = x_mul(a, b);
if(z == NULL) {
Py_DECREF(a);
Py_DECREF(b);
return NULL;
}
if (a->ob_size < 0)
z->ob_size = -(z->ob_size);
if (b->ob_size < 0)
@ -2003,11 +2187,6 @@ lshift_error:
/* Bitwise and/xor/or operations */
#undef MIN
#undef MAX
#define MAX(x, y) ((x) < (y) ? (y) : (x))
#define MIN(x, y) ((x) > (y) ? (y) : (x))
static PyObject *
long_bitwise(PyLongObject *a,
int op, /* '&', '|', '^' */