From 47e52ee0c5c8868db903d476b49c3368fce4d79a Mon Sep 17 00:00:00 2001 From: Tim Peters Date: Mon, 30 Aug 2004 02:44:38 +0000 Subject: [PATCH] SF patch 936813: fast modular exponentiation This checkin is adapted from part 2 (of 3) of Trevor Perrin's patch set. BACKWARD INCOMPATIBILITY: SHIFT must now be divisible by 5. AFAIK, nobody will care. long_pow() could be complicated to worm around that, if necessary. long_pow(): - BUGFIX: This leaked the base and power when the power was negative (and so the computation delegated to float pow). - Instead of doing right-to-left exponentiation, do left-to-right. This is more efficient for small bases, which is the common case. - In addition, if the exponent is large (more than FIVEARY_CUTOFF digits), precompute [a**i % c for i in range(32)], and go left to right 5 bits at a time. l_divmod(): - The signature changed so that callers who don't want the quotient, or don't want the remainder, can pass NULL in the slot they don't want. This saves them from having to declare a vrbl for unwanted stuff, and remembering to decref it. long_mod(), long_div(), long_classic_div(): - Adjust to new l_divmod() signature, and simplified as a result. --- Include/longintrepr.h | 7 +- Misc/NEWS | 14 +- Objects/longobject.c | 306 ++++++++++++++++++++++++++---------------- 3 files changed, 211 insertions(+), 116 deletions(-) diff --git a/Include/longintrepr.h b/Include/longintrepr.h index 9ed1fe737b7..254076e4d4d 100644 --- a/Include/longintrepr.h +++ b/Include/longintrepr.h @@ -15,7 +15,8 @@ extern "C" { (at most (BASE-1)*(2*BASE+1) == MASK*(2*MASK+3)). Also, x_sub assumes that 'digit' is an unsigned type, and overflow is handled by taking the result mod 2**N for some N > SHIFT. - And, at some places it is assumed that MASK fits in an int, as well. */ + And, at some places it is assumed that MASK fits in an int, as well. + long_pow() requires that SHIFT be divisible by 5. */ typedef unsigned short digit; typedef unsigned int wdigit; /* digit widened to parameter size */ @@ -27,6 +28,10 @@ typedef BASE_TWODIGITS_TYPE stwodigits; /* signed variant of twodigits */ #define BASE ((digit)1 << SHIFT) #define MASK ((int)(BASE - 1)) +#if SHIFT % 5 != 0 +#error "longobject.c requires that SHIFT be divisible by 5" +#endif + /* Long integer representation. The absolute value of a number is equal to SUM(for i=0 through abs(ob_size)-1) ob_digit[i] * 2**(SHIFT*i) diff --git a/Misc/NEWS b/Misc/NEWS index 431b343aa06..660c49fa755 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -20,7 +20,11 @@ Core and builtins to compute 17**1000000 dropped from about 14 seconds to 9 on my box due to this much. The cutoff for Karatsuba multiplication was raised, since gradeschool multiplication got quicker, and the cutoff was - aggressively small regardless. + aggressively small regardless. The exponentiation algorithm was switched + from right-to-left to left-to-right, which is more efficient for small + bases. In addition, if the exponent is large, the algorithm now does + 5 bits (instead of 1 bit) at a time. That cut the time to compute + 17**1000000 on my box in half again, down to about 4.5 seconds. - OverflowWarning is no longer generated. PEP 237 scheduled this to occur in Python 2.3, but since OverflowWarning was disabled by default, @@ -156,6 +160,14 @@ Tools/Demos Build ----- +- Backward incompatibility: longintrepr.h now triggers a compile-time + error if SHIFT (the number of bits in a Python long "digit") isn't + divisible by 5. This new requirement allows simple code for the new + 5-bits-at-a-time long_pow() implementation. If necessary, the + restriction could be removed (by complicating long_pow(), or by + falling back to the 1-bit-at-a-time algorithm), but there are no + plans to do so. + - bug #991962: When building with --disable-toolbox-glue on Darwin no attempt to build Mac-specific modules occurs. diff --git a/Objects/longobject.c b/Objects/longobject.c index 2f6d103bfec..05c42ad47d7 100644 --- a/Objects/longobject.c +++ b/Objects/longobject.c @@ -15,6 +15,13 @@ #define KARATSUBA_CUTOFF 70 #define KARATSUBA_SQUARE_CUTOFF (2 * KARATSUBA_CUTOFF) +/* For exponentiation, use the binary left-to-right algorithm + * unless the exponent contains more than FIVEARY_CUTOFF digits. + * In that case, do 5 bits at a time. The potential drawback is that + * a table of 2**5 intermediate results is computed. + */ +#define FIVEARY_CUTOFF 8 + #define ABS(x) ((x) < 0 ? -(x) : (x)) #undef MIN @@ -2136,6 +2143,12 @@ long_mul(PyLongObject *v, PyLongObject *w) have different signs. We then subtract one from the 'div' part of the outcome to keep the invariant intact. */ +/* Compute + * *pdiv, *pmod = divmod(v, w) + * NULL can be passed for pdiv or pmod, in which case that part of + * the result is simply thrown away. The caller owns a reference to + * each of these it requests (does not pass NULL for). + */ static int l_divmod(PyLongObject *v, PyLongObject *w, PyLongObject **pdiv, PyLongObject **pmod) @@ -2167,44 +2180,43 @@ l_divmod(PyLongObject *v, PyLongObject *w, Py_DECREF(div); div = temp; } - *pdiv = div; - *pmod = mod; + if (pdiv != NULL) + *pdiv = div; + else + Py_DECREF(div); + + if (pmod != NULL) + *pmod = mod; + else + Py_DECREF(mod); + return 0; } static PyObject * long_div(PyObject *v, PyObject *w) { - PyLongObject *a, *b, *div, *mod; + PyLongObject *a, *b, *div; CONVERT_BINOP(v, w, &a, &b); - - if (l_divmod(a, b, &div, &mod) < 0) { - Py_DECREF(a); - Py_DECREF(b); - return NULL; - } + if (l_divmod(a, b, &div, NULL) < 0) + div = NULL; Py_DECREF(a); Py_DECREF(b); - Py_DECREF(mod); return (PyObject *)div; } static PyObject * long_classic_div(PyObject *v, PyObject *w) { - PyLongObject *a, *b, *div, *mod; + PyLongObject *a, *b, *div; CONVERT_BINOP(v, w, &a, &b); - if (Py_DivisionWarningFlag && PyErr_Warn(PyExc_DeprecationWarning, "classic long division") < 0) div = NULL; - else if (l_divmod(a, b, &div, &mod) < 0) + else if (l_divmod(a, b, &div, NULL) < 0) div = NULL; - else - Py_DECREF(mod); - Py_DECREF(a); Py_DECREF(b); return (PyObject *)div; @@ -2255,18 +2267,14 @@ overflow: static PyObject * long_mod(PyObject *v, PyObject *w) { - PyLongObject *a, *b, *div, *mod; + PyLongObject *a, *b, *mod; CONVERT_BINOP(v, w, &a, &b); - if (l_divmod(a, b, &div, &mod) < 0) { - Py_DECREF(a); - Py_DECREF(b); - return NULL; - } + if (l_divmod(a, b, NULL, &mod) < 0) + mod = NULL; Py_DECREF(a); Py_DECREF(b); - Py_DECREF(div); return (PyObject *)mod; } @@ -2297,22 +2305,33 @@ long_divmod(PyObject *v, PyObject *w) return z; } +/* pow(v, w, x) */ static PyObject * long_pow(PyObject *v, PyObject *w, PyObject *x) { - PyLongObject *a, *b; - PyObject *c; - PyLongObject *z, *div, *mod; - int size_b, i; + PyLongObject *a, *b, *c; /* a,b,c = v,w,x */ + int negativeOutput = 0; /* if x<0 return negative output */ + PyLongObject *z = NULL; /* accumulated result */ + int i, j, k; /* counters */ + PyLongObject *temp = NULL; + + /* 5-ary values. If the exponent is large enough, table is + * precomputed so that table[i] == a**i % c for i in range(32). + */ + PyLongObject *table[32] = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, + 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}; + + /* a, b, c = v, w, x */ CONVERT_BINOP(v, w, &a, &b); - if (PyLong_Check(x) || Py_None == x) { - c = x; + if (PyLong_Check(x)) { + c = (PyLongObject *)x; Py_INCREF(x); } - else if (PyInt_Check(x)) { - c = PyLong_FromLong(PyInt_AS_LONG(x)); - } + else if (PyInt_Check(x)) + c = (PyLongObject *)PyLong_FromLong(PyInt_AS_LONG(x)); + else if (x == Py_None) + c = NULL; else { Py_DECREF(a); Py_DECREF(b); @@ -2320,95 +2339,154 @@ long_pow(PyObject *v, PyObject *w, PyObject *x) return Py_NotImplemented; } - if (c != Py_None && ((PyLongObject *)c)->ob_size == 0) { - PyErr_SetString(PyExc_ValueError, - "pow() 3rd argument cannot be 0"); - z = NULL; - goto error; - } - - size_b = b->ob_size; - if (size_b < 0) { - Py_DECREF(a); - Py_DECREF(b); - Py_DECREF(c); - if (x != Py_None) { + if (b->ob_size < 0) { /* if exponent is negative */ + if (c) { PyErr_SetString(PyExc_TypeError, "pow() 2nd argument " - "cannot be negative when 3rd argument specified"); + "cannot be negative when 3rd argument specified"); return NULL; } - /* Return a float. This works because we know that - this calls float_pow() which converts its - arguments to double. */ - return PyFloat_Type.tp_as_number->nb_power(v, w, x); - } - z = (PyLongObject *)PyLong_FromLong(1L); - for (i = 0; i < size_b; ++i) { - digit bi = b->ob_digit[i]; - int j; - - for (j = 0; j < SHIFT; ++j) { - PyLongObject *temp; - - if (bi & 1) { - temp = (PyLongObject *)long_mul(z, a); - Py_DECREF(z); - if (c!=Py_None && temp!=NULL) { - if (l_divmod(temp,(PyLongObject *)c, - &div,&mod) < 0) { - Py_DECREF(temp); - z = NULL; - goto error; - } - Py_XDECREF(div); - Py_DECREF(temp); - temp = mod; - } - z = temp; - if (z == NULL) - break; - } - bi >>= 1; - if (bi == 0 && i+1 == size_b) - break; - temp = (PyLongObject *)long_mul(a, a); - Py_DECREF(a); - if (c!=Py_None && temp!=NULL) { - if (l_divmod(temp, (PyLongObject *)c, &div, - &mod) < 0) { - Py_DECREF(temp); - z = NULL; - goto error; - } - Py_XDECREF(div); - Py_DECREF(temp); - temp = mod; - } - a = temp; - if (a == NULL) { - Py_DECREF(z); - z = NULL; - break; - } - } - if (a == NULL || z == NULL) - break; - } - if (c!=Py_None && z!=NULL) { - if (l_divmod(z, (PyLongObject *)c, &div, &mod) < 0) { - Py_DECREF(z); - z = NULL; - } else { - Py_XDECREF(div); - Py_DECREF(z); - z = mod; + /* else return a float. This works because we know + that this calls float_pow() which converts its + arguments to double. */ + Py_DECREF(a); + Py_DECREF(b); + return PyFloat_Type.tp_as_number->nb_power(v, w, x); } } - error: + + if (c) { + /* if modulus == 0: + raise ValueError() */ + if (c->ob_size == 0) { + PyErr_SetString(PyExc_ValueError, + "pow() 3rd argument cannot be 0"); + goto Done; + } + + /* if modulus < 0: + negativeOutput = True + modulus = -modulus */ + if (c->ob_size < 0) { + negativeOutput = 1; + temp = (PyLongObject *)_PyLong_Copy(c); + if (temp == NULL) + goto Error; + Py_DECREF(c); + c = temp; + temp = NULL; + c->ob_size = - c->ob_size; + } + + /* if modulus == 1: + return 0 */ + if ((c->ob_size == 1) && (c->ob_digit[0] == 1)) { + z = (PyLongObject *)PyLong_FromLong(0L); + goto Done; + } + + /* if base < 0: + base = base % modulus + Having the base positive just makes things easier. */ + if (a->ob_size < 0) { + if (l_divmod(a, c, NULL, &temp) < 0) + goto Done; + Py_DECREF(a); + a = temp; + temp = NULL; + } + } + + /* At this point a, b, and c are guaranteed non-negative UNLESS + c is NULL, in which case a may be negative. */ + + z = (PyLongObject *)PyLong_FromLong(1L); + if (z == NULL) + goto Error; + + /* Perform a modular reduction, X = X % c, but leave X alone if c + * is NULL. + */ +#define REDUCE(X) \ + if (c != NULL) { \ + if (l_divmod(X, c, NULL, &temp) < 0) \ + goto Error; \ + Py_XDECREF(X); \ + X = temp; \ + temp = NULL; \ + } + + /* Multiply two values, then reduce the result: + result = X*Y % c. If c is NULL, skip the mod. */ +#define MULT(X, Y, result) \ +{ \ + temp = (PyLongObject *)long_mul(X, Y); \ + if (temp == NULL) \ + goto Error; \ + Py_XDECREF(result); \ + result = temp; \ + temp = NULL; \ + REDUCE(result) \ +} + + if (b->ob_size <= FIVEARY_CUTOFF) { + /* Left-to-right binary exponentiation (HAC Algorithm 14.79) */ + /* http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf */ + for (i = b->ob_size - 1; i >= 0; --i) { + digit bi = b->ob_digit[i]; + + for (j = 1 << (SHIFT-1); j != 0; j >>= 1) { + MULT(z, z, z) + if (bi & j) + MULT(z, a, z) + } + } + } + else { + /* Left-to-right 5-ary exponentiation (HAC Algorithm 14.82) */ + Py_INCREF(z); /* still holds 1L */ + table[0] = z; + for (i = 1; i < 32; ++i) + MULT(table[i-1], a, table[i]) + + for (i = b->ob_size - 1; i >= 0; --i) { + const digit bi = b->ob_digit[i]; + + for (j = SHIFT - 5; j >= 0; j -= 5) { + const int index = (bi >> j) & 0x1f; + for (k = 0; k < 5; ++k) + MULT(z, z, z) + if (index) + MULT(z, table[index], z) + } + } + } + + if (negativeOutput && (z->ob_size != 0)) { + temp = (PyLongObject *)long_sub(z, c); + if (temp == NULL) + goto Error; + Py_DECREF(z); + z = temp; + temp = NULL; + } + goto Done; + + Error: + if (z != NULL) { + Py_DECREF(z); + z = NULL; + } + /* fall through */ + Done: Py_XDECREF(a); - Py_DECREF(b); - Py_DECREF(c); + Py_XDECREF(b); + Py_XDECREF(c); + Py_XDECREF(temp); + if (b->ob_size > FIVEARY_CUTOFF) { + for (i = 0; i < 32; ++i) + Py_XDECREF(table[i]); + } return (PyObject *)z; }