diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c index 62d327998fd..37934f60e9c 100644 --- a/Modules/mathmodule.c +++ b/Modules/mathmodule.c @@ -2037,26 +2037,32 @@ where *max* is the largest value in the vector, compute: max * sqrt(sum((x / max) ** 2 for x in vec)) -When a maximum value is found, it is swapped to the end. This -lets us skip one loop iteration and just add 1.0 at the end. -Saving the largest value for last also helps improve accuracy. - -Kahan summation is used to improve accuracy. The *csum* -variable tracks the cumulative sum and *frac* tracks -fractional round-off error for the most recent addition. - The value of the *max* variable must be present in *vec* or should equal to 0.0 when n==0. Likewise, *max* will be INF if an infinity is present in the vec. The *found_nan* variable indicates whether some member of the *vec* is a NaN. + +To improve accuracy and to increase the number of cases where +vector_norm() is commutative, we use a variant of Neumaier +summation specialized to exploit that we always know that +|csum| >= |x|. + +The *csum* variable tracks the cumulative sum and *frac* tracks +the cumulative fractional errors at each step. Since this +variant assumes that |csum| >= |x| at each step, we establish +the precondition by starting the accumulation from 1.0 which +represents an entry equal to *max*. This also provides a nice +side benefit in that it lets us skip over a *max* entry (which +is swapped into *last*) saving us one iteration through the loop. + */ static inline double vector_norm(Py_ssize_t n, double *vec, double max, int found_nan) { - double x, csum = 0.0, oldcsum, frac = 0.0, last; + double x, csum = 1.0, oldcsum, frac = 0.0, last; Py_ssize_t i; if (Py_IS_INFINITY(max)) { @@ -2078,14 +2084,14 @@ vector_norm(Py_ssize_t n, double *vec, double max, int found_nan) last = max; } x /= max; - x = x*x - frac; + x = x*x; + assert(csum >= x); oldcsum = csum; csum += x; - frac = (csum - oldcsum) - x; + frac += (oldcsum - csum) + x; } assert(last == max); - csum += 1.0 - frac; - return max * sqrt(csum); + return max * sqrt(csum + frac); } #define NUM_STACK_ELEMS 16