Issue #2819: Add math.sum, a function that sums a sequence of floats

efficiently but with no intermediate loss of precision.  Based on
Raymond Hettinger's ASPN recipe.  Thanks Jean Brouwers for the patch.
This commit is contained in:
Mark Dickinson 2008-05-23 01:35:30 +00:00
parent cc858ccc50
commit 99dfe92759
2 changed files with 226 additions and 0 deletions

View File

@ -36,6 +36,9 @@ Core and Builtins
Extension Modules
-----------------
- Issue #2819: add full-precision summation function to math module,
based on Hettinger's ASPN Python Cookbook recipe.
- Issue #2592: delegate nb_index and the floor/truediv slots in
weakref.proxy.

View File

@ -307,6 +307,228 @@ FUNC1(tan, tan, 0,
/* math.tanh.  NOTE(review): FUNC1 is defined outside this view; presumably
   it expands to a one-argument wrapper around the C library tanh() that
   carries the docstring below -- confirm against the macro definition. */
FUNC1(tanh, tanh, 0,
"tanh(x)\n\nReturn the hyperbolic tangent of x.")
/* Precision summation function as msum() by Raymond Hettinger in
<http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/393090>,
enhanced with the exact partials sum and roundoff from Mark
Dickinson's post at <http://bugs.python.org/file10357/msum4.py>.
See both of those for more details, proofs and other references.
Note 1: IEEE 754 floating point format and semantics are assumed, but not
explicitly maintained. The following rules may not apply:
1. if the summands include a NaN, return a NaN,
2. if the summands include infinities of both signs, raise ValueError,
3. if the summands include infinities of only one sign, return infinity
with that sign,
4. otherwise (all summands are finite) if the result is infinite, raise
OverflowError. The result can never be a NaN if all summands are
finite.
Note 2: the implementation below does not include the intermediate overflow
handling from Mark Dickinson's msum(). Therefore, sum([1e+308, -1e+308,
1e+308]) returns 1e+308, whereas sum([1e+308, 1e+308, -1e+308])
raises an OverflowError due to intermediate overflow of the first
partial sum.
Note 3: aggressively optimizing compilers may eliminate the roundoff
expressions critical for accurate summation. For example, the compiler
may optimize the following expressions
hi = x + y;
lo = y - (hi - x);
to
hi = x + y;
lo = 0.0;
defeating the whole purpose. Using volatile variables and/or explicit
assignment of critical subexpressions to a volatile variable should
remedy the problem
volatile double v; // Deter compiler from algebraically optimizing
// this critical, intermediate value away
hi = x + y;
v = hi - x;
lo = y - v;
by forcing the compiler to compute the value for v. This may also help
when subexpressions are not computed with the full double precision.
Note 4: the same summation functions may be in ./cmathmodule.c. Make
sure to update both when making changes.
*/
#define NUM_PARTIALS 32 /* initial partials array size, on stack */

/* Grow the partials array to twice its current capacity.

   p_ptr  in/out: address of the pointer to the current partials array
   n      number of partials currently stored (these are preserved)
   ps     the caller's original on-stack array; while *p_ptr still equals
          ps the data must be copied into a fresh heap block instead of
          being realloc'ed, because ps was never heap-allocated
   m_ptr  in/out: address of the current capacity, counted in doubles

   Returns 0 on success, with *p_ptr and *m_ptr updated.  Returns non-zero
   with MemoryError set when the doubled size is out of range or the
   allocation fails; *p_ptr and *m_ptr are left untouched in that case.
*/
static int                          /* non-zero on error */
_sum_realloc(double **p_ptr, Py_ssize_t n,
             double *ps, Py_ssize_t *m_ptr)
{
    double *old_buf = *p_ptr;
    Py_ssize_t new_cap = *m_ptr * 2;
    void *new_buf = NULL;

    /* Only allocate when the doubled capacity really has room for the
       existing partials and its byte size stays in range. */
    if (n < new_cap && new_cap < (PY_SSIZE_T_MAX / sizeof(double))) {
        if (old_buf != ps) {
            /* already on the heap: let the allocator move the data */
            new_buf = PyMem_Realloc(old_buf, sizeof(double) * new_cap);
        }
        else {
            /* first growth: migrate from the stack array to the heap */
            new_buf = PyMem_Malloc(sizeof(double) * new_cap);
            if (new_buf != NULL)
                memcpy(new_buf, ps, sizeof(double) * n);
        }
    }

    if (new_buf == NULL) {          /* size overflow or no memory */
        PyErr_SetString(PyExc_MemoryError, "math sum partials");
        return 1;
    }

    *p_ptr = (double *) new_buf;
    *m_ptr = new_cap;
    return 0;
}
/* Full precision summation of a sequence of floats.

   def msum(iterable):
       partials = []  # sorted, non-overlapping partial sums
       for x in iterable:
           i = 0
           for y in partials:
               if abs(x) < abs(y):
                   x, y = y, x
               hi = x + y
               lo = y - (hi - x)
               if lo:
                   partials[i] = lo
                   i += 1
               x = hi
           partials[i:] = [x]
       return sum_exact(partials)

   Rounded x+y is stored in hi with the roundoff stored in lo.  Together
   hi+lo are exactly equal to x+y.  The inner loop applies hi/lo summation
   to each partial so that the list of partial sums remains exact.

   Sum_exact() adds the partial sums exactly and correctly rounds the final
   result (using the round-half-to-even rule).  The items in partials remain
   non-zero, non-special, non-overlapping and strictly increasing in
   magnitude, but possibly not all having the same sign.

   Depends on IEEE 754 arithmetic guarantees.

   Arguments:
     self  unused.
     seq   any iterable; every item is converted with PyFloat_AsDouble().

   Returns a new float object holding the sum (0.0 for an empty iterable),
   or NULL with an exception set when iteration or conversion of an item
   fails, when the partials array cannot be grown (MemoryError from
   _sum_realloc), or when the accumulated value is not finite and
   is_error() reports an error for it.

   hi, yr and lo are declared volatile and every roundoff is computed in
   separately stored steps (hi = x + y; yr = hi - x; lo = y - yr).  This
   keeps the compiler from algebraically simplifying lo to 0.0 and forces
   each intermediate to be stored as a double rather than kept in a wider
   register; either effect would silently destroy the exactness of the
   partials.
*/
static PyObject*
math_sum(PyObject *self, PyObject *seq)
{
    PyObject *item, *iter, *sum = NULL;
    Py_ssize_t i, j, n = 0, m = NUM_PARTIALS;
    double x, y, t, ps[NUM_PARTIALS], *p = ps;
    volatile double hi, yr, lo = 0.0;

    iter = PyObject_GetIter(seq);
    if (iter == NULL)
        return NULL;

    PyFPE_START_PROTECT("sum", Py_DECREF(iter); return NULL)

    for(;;) {           /* for x in iterable */
        /* some invariants */
        assert(0 <= n && n <= m);
        assert((m == NUM_PARTIALS && p == ps) ||
               (m > NUM_PARTIALS && p != NULL));

        item = PyIter_Next(iter);
        if (item == NULL) {
            if (PyErr_Occurred())
                goto _sum_error;
            else
                break;      /* iterator exhausted */
        }
        x = PyFloat_AsDouble(item);
        Py_DECREF(item);
        if (PyErr_Occurred())
            goto _sum_error;

        for (i = j = 0; j < n; j++) {       /* for y in partials */
            y = p[j];
            /* make x the operand of larger magnitude so that the
               roundoff y - ((x + y) - x) is exact */
            if (fabs(x) < fabs(y)) {
                t = x; x = y; y = t;
            }
            hi = x + y;
            yr = hi - x;
            lo = y - yr;
            if (lo != 0.0)
                p[i++] = lo;    /* keep only non-zero partials */
            x = hi;
        }

        /* ps[i:] = [x] */
        n = i;
        if (x != 0.0) {
            /* if non-finite, reset partials, effectively
               adding subsequent items without roundoff
               and yielding correct non-finite results,
               provided IEEE 754 rules are observed */
            if (! Py_IS_FINITE(x))
                n = 0;
            else if (n >= m && _sum_realloc(&p, n, ps, &m))
                goto _sum_error;
            p[n++] = x;
        }
    }

    assert(n <= m);

    if (n > 0) {
        hi = p[--n];
        if (Py_IS_FINITE(hi)) {
            /* sum_exact(ps, hi) from the top, stop
               as soon as the sum becomes inexact */
            while (n > 0) {
                x = hi;
                y = p[--n];
                assert(fabs(y) < fabs(x));
                hi = x + y;
                yr = hi - x;
                lo = y - yr;
                if (lo != 0.0)
                    break;
            }
            /* round correctly if necessary: when the remaining roundoff
               and the next lower partial have the same sign, an exact
               half-way case has to be pushed away from hi */
            if (n > 0 && ((lo < 0.0 && p[n-1] < 0.0) ||
                          (lo > 0.0 && p[n-1] > 0.0))) {
                y = lo * 2.0;
                x = hi + y;
                yr = x - hi;
                if (y == yr)
                    hi = x;
            }
        }
        else {  /* raise corresponding error */
            errno = Py_IS_NAN(hi) ? EDOM : ERANGE;
            if (is_error(hi))
                goto _sum_error;
        }
    }
    else    /* default */
        hi = 0.0;

    sum = PyFloat_FromDouble(hi);

_sum_error:
    /* common exit: sum is still NULL on every error path */
    PyFPE_END_PROTECT(hi)
    Py_DECREF(iter);
    if (p != ps)
        PyMem_Free(p);      /* partials were moved to the heap */
    return sum;
}

#undef NUM_PARTIALS
/* Docstring for math_sum(); it is attached to the "sum" entry of the
   math_methods table. */
PyDoc_STRVAR(math_sum_doc,
"sum(sequence)\n\n\
Return the full precision sum of a sequence of numbers.\n\
When the sequence is empty, return zero.\n\n\
For accurate results, IEEE 754 floating point format\n\
and semantics and floating point radix 2 are required.");
static PyObject *
math_trunc(PyObject *self, PyObject *number)
{
@ -760,6 +982,7 @@ static PyMethodDef math_methods[] = {
{"sin", math_sin, METH_O, math_sin_doc},
{"sinh", math_sinh, METH_O, math_sinh_doc},
{"sqrt", math_sqrt, METH_O, math_sqrt_doc},
{"sum", math_sum, METH_O, math_sum_doc},
{"tan", math_tan, METH_O, math_tan_doc},
{"tanh", math_tanh, METH_O, math_tanh_doc},
{"trunc", math_trunc, METH_O, math_trunc_doc},