bpo-46257: Convert statistics._ss() to a single pass algorithm (GH-30403)

This commit is contained in:
Raymond Hettinger 2022-01-05 07:39:10 -08:00 committed by GitHub
parent 46e4c257e7
commit 43aac29cbb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 47 additions and 57 deletions

View File

@ -138,7 +138,7 @@ from itertools import groupby, repeat
from bisect import bisect_left, bisect_right from bisect import bisect_left, bisect_right
from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
from operator import mul from operator import mul
from collections import Counter, namedtuple from collections import Counter, namedtuple, defaultdict
_SQRT2 = sqrt(2.0) _SQRT2 = sqrt(2.0)
@ -202,6 +202,43 @@ def _sum(data):
return (T, total, count) return (T, total, count)
def _ss(data, c=None):
"""Return sum of square deviations of sequence data.
If ``c`` is None, the mean is calculated in one pass, and the deviations
from the mean are calculated in a second pass. Otherwise, deviations are
calculated from ``c`` as given. Use the second case with care, as it can
lead to garbage results.
"""
if c is not None:
T, total, count = _sum((d := x - c) * d for x in data)
return (T, total, count)
count = 0
sx_partials = defaultdict(int)
sxx_partials = defaultdict(int)
T = int
for typ, values in groupby(data, type):
T = _coerce(T, typ) # or raise TypeError
for n, d in map(_exact_ratio, values):
count += 1
sx_partials[d] += n
sxx_partials[d] += n * n
if not count:
total = Fraction(0)
elif None in sx_partials:
# The sum will be a NAN or INF. We can ignore all the finite
# partials, and just look at this special one.
total = sx_partials[None]
assert not _isfinite(total)
else:
sx = sum(Fraction(n, d) for d, n in sx_partials.items())
sxx = sum(Fraction(n, d*d) for d, n in sxx_partials.items())
# This formula has poor numeric properties for floats,
# but with fractions it is exact.
total = (count * sxx - sx * sx) / count
return (T, total, count)
def _isfinite(x): def _isfinite(x):
try: try:
return x.is_finite() # Likely a Decimal. return x.is_finite() # Likely a Decimal.
@ -399,13 +436,9 @@ def mean(data):
If ``data`` is empty, StatisticsError will be raised. If ``data`` is empty, StatisticsError will be raised.
""" """
if iter(data) is data: T, total, n = _sum(data)
data = list(data)
n = len(data)
if n < 1: if n < 1:
raise StatisticsError('mean requires at least one data point') raise StatisticsError('mean requires at least one data point')
T, total, count = _sum(data)
assert count == n
return _convert(total / n, T) return _convert(total / n, T)
@ -776,41 +809,6 @@ def quantiles(data, *, n=4, method='exclusive'):
# See http://mathworld.wolfram.com/Variance.html # See http://mathworld.wolfram.com/Variance.html
# http://mathworld.wolfram.com/SampleVariance.html # http://mathworld.wolfram.com/SampleVariance.html
# http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
#
# Under no circumstances use the so-called "computational formula for
# variance", as that is only suitable for hand calculations with a small
# amount of low-precision data. It has terrible numeric properties.
#
# See a comparison of three computational methods here:
# http://www.johndcook.com/blog/2008/09/26/comparing-three-methods-of-computing-standard-deviation/
def _ss(data, c=None):
"""Return sum of square deviations of sequence data.
If ``c`` is None, the mean is calculated in one pass, and the deviations
from the mean are calculated in a second pass. Otherwise, deviations are
calculated from ``c`` as given. Use the second case with care, as it can
lead to garbage results.
"""
if c is not None:
T, total, count = _sum((d := x - c) * d for x in data)
return (T, total)
T, total, count = _sum(data)
mean_n, mean_d = (total / count).as_integer_ratio()
partials = Counter()
for n, d in map(_exact_ratio, data):
diff_n = n * mean_d - d * mean_n
diff_d = d * mean_d
partials[diff_d * diff_d] += diff_n * diff_n
if None in partials:
# The sum will be a NAN or INF. We can ignore all the finite
# partials, and just look at this special one.
total = partials[None]
assert not _isfinite(total)
else:
total = sum(Fraction(n, d) for d, n in partials.items())
return (T, total)
def variance(data, xbar=None): def variance(data, xbar=None):
@ -851,12 +849,9 @@ def variance(data, xbar=None):
Fraction(67, 108) Fraction(67, 108)
""" """
if iter(data) is data: T, ss, n = _ss(data, xbar)
data = list(data)
n = len(data)
if n < 2: if n < 2:
raise StatisticsError('variance requires at least two data points') raise StatisticsError('variance requires at least two data points')
T, ss = _ss(data, xbar)
return _convert(ss / (n - 1), T) return _convert(ss / (n - 1), T)
@ -895,12 +890,9 @@ def pvariance(data, mu=None):
Fraction(13, 72) Fraction(13, 72)
""" """
if iter(data) is data: T, ss, n = _ss(data, mu)
data = list(data)
n = len(data)
if n < 1: if n < 1:
raise StatisticsError('pvariance requires at least one data point') raise StatisticsError('pvariance requires at least one data point')
T, ss = _ss(data, mu)
return _convert(ss / n, T) return _convert(ss / n, T)
@ -913,12 +905,9 @@ def stdev(data, xbar=None):
1.0810874155219827 1.0810874155219827
""" """
if iter(data) is data: T, ss, n = _ss(data, xbar)
data = list(data)
n = len(data)
if n < 2: if n < 2:
raise StatisticsError('stdev requires at least two data points') raise StatisticsError('stdev requires at least two data points')
T, ss = _ss(data, xbar)
mss = ss / (n - 1) mss = ss / (n - 1)
if issubclass(T, Decimal): if issubclass(T, Decimal):
return _decimal_sqrt_of_frac(mss.numerator, mss.denominator) return _decimal_sqrt_of_frac(mss.numerator, mss.denominator)
@ -934,12 +923,9 @@ def pstdev(data, mu=None):
0.986893273527251 0.986893273527251
""" """
if iter(data) is data: T, ss, n = _ss(data, mu)
data = list(data)
n = len(data)
if n < 1: if n < 1:
raise StatisticsError('pstdev requires at least one data point') raise StatisticsError('pstdev requires at least one data point')
T, ss = _ss(data, mu)
mss = ss / n mss = ss / n
if issubclass(T, Decimal): if issubclass(T, Decimal):
return _decimal_sqrt_of_frac(mss.numerator, mss.denominator) return _decimal_sqrt_of_frac(mss.numerator, mss.denominator)

View File

@ -0,0 +1,4 @@
Optimized the mean, variance, and stdev functions in the statistics module.
If the input is an iterator, it is consumed in a single pass rather than
eating memory by conversion to a list. The single pass algorithm is about
twice as fast as the previous two pass code.