mirror of https://github.com/python/cpython
bpo-46257: Convert statistics._ss() to a single pass algorithm (GH-30403)
This commit is contained in:
parent
46e4c257e7
commit
43aac29cbb
|
@ -138,7 +138,7 @@ from itertools import groupby, repeat
|
||||||
from bisect import bisect_left, bisect_right
|
from bisect import bisect_left, bisect_right
|
||||||
from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
|
from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
|
||||||
from operator import mul
|
from operator import mul
|
||||||
from collections import Counter, namedtuple
|
from collections import Counter, namedtuple, defaultdict
|
||||||
|
|
||||||
_SQRT2 = sqrt(2.0)
|
_SQRT2 = sqrt(2.0)
|
||||||
|
|
||||||
|
@ -202,6 +202,43 @@ def _sum(data):
|
||||||
return (T, total, count)
|
return (T, total, count)
|
||||||
|
|
||||||
|
|
||||||
|
def _ss(data, c=None):
|
||||||
|
"""Return sum of square deviations of sequence data.
|
||||||
|
|
||||||
|
If ``c`` is None, the mean is calculated in one pass, and the deviations
|
||||||
|
from the mean are calculated in a second pass. Otherwise, deviations are
|
||||||
|
calculated from ``c`` as given. Use the second case with care, as it can
|
||||||
|
lead to garbage results.
|
||||||
|
"""
|
||||||
|
if c is not None:
|
||||||
|
T, total, count = _sum((d := x - c) * d for x in data)
|
||||||
|
return (T, total, count)
|
||||||
|
count = 0
|
||||||
|
sx_partials = defaultdict(int)
|
||||||
|
sxx_partials = defaultdict(int)
|
||||||
|
T = int
|
||||||
|
for typ, values in groupby(data, type):
|
||||||
|
T = _coerce(T, typ) # or raise TypeError
|
||||||
|
for n, d in map(_exact_ratio, values):
|
||||||
|
count += 1
|
||||||
|
sx_partials[d] += n
|
||||||
|
sxx_partials[d] += n * n
|
||||||
|
if not count:
|
||||||
|
total = Fraction(0)
|
||||||
|
elif None in sx_partials:
|
||||||
|
# The sum will be a NAN or INF. We can ignore all the finite
|
||||||
|
# partials, and just look at this special one.
|
||||||
|
total = sx_partials[None]
|
||||||
|
assert not _isfinite(total)
|
||||||
|
else:
|
||||||
|
sx = sum(Fraction(n, d) for d, n in sx_partials.items())
|
||||||
|
sxx = sum(Fraction(n, d*d) for d, n in sxx_partials.items())
|
||||||
|
# This formula has poor numeric properties for floats,
|
||||||
|
# but with fractions it is exact.
|
||||||
|
total = (count * sxx - sx * sx) / count
|
||||||
|
return (T, total, count)
|
||||||
|
|
||||||
|
|
||||||
def _isfinite(x):
|
def _isfinite(x):
|
||||||
try:
|
try:
|
||||||
return x.is_finite() # Likely a Decimal.
|
return x.is_finite() # Likely a Decimal.
|
||||||
|
@ -399,13 +436,9 @@ def mean(data):
|
||||||
|
|
||||||
If ``data`` is empty, StatisticsError will be raised.
|
If ``data`` is empty, StatisticsError will be raised.
|
||||||
"""
|
"""
|
||||||
if iter(data) is data:
|
T, total, n = _sum(data)
|
||||||
data = list(data)
|
|
||||||
n = len(data)
|
|
||||||
if n < 1:
|
if n < 1:
|
||||||
raise StatisticsError('mean requires at least one data point')
|
raise StatisticsError('mean requires at least one data point')
|
||||||
T, total, count = _sum(data)
|
|
||||||
assert count == n
|
|
||||||
return _convert(total / n, T)
|
return _convert(total / n, T)
|
||||||
|
|
||||||
|
|
||||||
|
@ -776,41 +809,6 @@ def quantiles(data, *, n=4, method='exclusive'):
|
||||||
|
|
||||||
# See http://mathworld.wolfram.com/Variance.html
|
# See http://mathworld.wolfram.com/Variance.html
|
||||||
# http://mathworld.wolfram.com/SampleVariance.html
|
# http://mathworld.wolfram.com/SampleVariance.html
|
||||||
# http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
|
|
||||||
#
|
|
||||||
# Under no circumstances use the so-called "computational formula for
|
|
||||||
# variance", as that is only suitable for hand calculations with a small
|
|
||||||
# amount of low-precision data. It has terrible numeric properties.
|
|
||||||
#
|
|
||||||
# See a comparison of three computational methods here:
|
|
||||||
# http://www.johndcook.com/blog/2008/09/26/comparing-three-methods-of-computing-standard-deviation/
|
|
||||||
|
|
||||||
def _ss(data, c=None):
|
|
||||||
"""Return sum of square deviations of sequence data.
|
|
||||||
|
|
||||||
If ``c`` is None, the mean is calculated in one pass, and the deviations
|
|
||||||
from the mean are calculated in a second pass. Otherwise, deviations are
|
|
||||||
calculated from ``c`` as given. Use the second case with care, as it can
|
|
||||||
lead to garbage results.
|
|
||||||
"""
|
|
||||||
if c is not None:
|
|
||||||
T, total, count = _sum((d := x - c) * d for x in data)
|
|
||||||
return (T, total)
|
|
||||||
T, total, count = _sum(data)
|
|
||||||
mean_n, mean_d = (total / count).as_integer_ratio()
|
|
||||||
partials = Counter()
|
|
||||||
for n, d in map(_exact_ratio, data):
|
|
||||||
diff_n = n * mean_d - d * mean_n
|
|
||||||
diff_d = d * mean_d
|
|
||||||
partials[diff_d * diff_d] += diff_n * diff_n
|
|
||||||
if None in partials:
|
|
||||||
# The sum will be a NAN or INF. We can ignore all the finite
|
|
||||||
# partials, and just look at this special one.
|
|
||||||
total = partials[None]
|
|
||||||
assert not _isfinite(total)
|
|
||||||
else:
|
|
||||||
total = sum(Fraction(n, d) for d, n in partials.items())
|
|
||||||
return (T, total)
|
|
||||||
|
|
||||||
|
|
||||||
def variance(data, xbar=None):
|
def variance(data, xbar=None):
|
||||||
|
@ -851,12 +849,9 @@ def variance(data, xbar=None):
|
||||||
Fraction(67, 108)
|
Fraction(67, 108)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if iter(data) is data:
|
T, ss, n = _ss(data, xbar)
|
||||||
data = list(data)
|
|
||||||
n = len(data)
|
|
||||||
if n < 2:
|
if n < 2:
|
||||||
raise StatisticsError('variance requires at least two data points')
|
raise StatisticsError('variance requires at least two data points')
|
||||||
T, ss = _ss(data, xbar)
|
|
||||||
return _convert(ss / (n - 1), T)
|
return _convert(ss / (n - 1), T)
|
||||||
|
|
||||||
|
|
||||||
|
@ -895,12 +890,9 @@ def pvariance(data, mu=None):
|
||||||
Fraction(13, 72)
|
Fraction(13, 72)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if iter(data) is data:
|
T, ss, n = _ss(data, mu)
|
||||||
data = list(data)
|
|
||||||
n = len(data)
|
|
||||||
if n < 1:
|
if n < 1:
|
||||||
raise StatisticsError('pvariance requires at least one data point')
|
raise StatisticsError('pvariance requires at least one data point')
|
||||||
T, ss = _ss(data, mu)
|
|
||||||
return _convert(ss / n, T)
|
return _convert(ss / n, T)
|
||||||
|
|
||||||
|
|
||||||
|
@ -913,12 +905,9 @@ def stdev(data, xbar=None):
|
||||||
1.0810874155219827
|
1.0810874155219827
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if iter(data) is data:
|
T, ss, n = _ss(data, xbar)
|
||||||
data = list(data)
|
|
||||||
n = len(data)
|
|
||||||
if n < 2:
|
if n < 2:
|
||||||
raise StatisticsError('stdev requires at least two data points')
|
raise StatisticsError('stdev requires at least two data points')
|
||||||
T, ss = _ss(data, xbar)
|
|
||||||
mss = ss / (n - 1)
|
mss = ss / (n - 1)
|
||||||
if issubclass(T, Decimal):
|
if issubclass(T, Decimal):
|
||||||
return _decimal_sqrt_of_frac(mss.numerator, mss.denominator)
|
return _decimal_sqrt_of_frac(mss.numerator, mss.denominator)
|
||||||
|
@ -934,12 +923,9 @@ def pstdev(data, mu=None):
|
||||||
0.986893273527251
|
0.986893273527251
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if iter(data) is data:
|
T, ss, n = _ss(data, mu)
|
||||||
data = list(data)
|
|
||||||
n = len(data)
|
|
||||||
if n < 1:
|
if n < 1:
|
||||||
raise StatisticsError('pstdev requires at least one data point')
|
raise StatisticsError('pstdev requires at least one data point')
|
||||||
T, ss = _ss(data, mu)
|
|
||||||
mss = ss / n
|
mss = ss / n
|
||||||
if issubclass(T, Decimal):
|
if issubclass(T, Decimal):
|
||||||
return _decimal_sqrt_of_frac(mss.numerator, mss.denominator)
|
return _decimal_sqrt_of_frac(mss.numerator, mss.denominator)
|
||||||
|
|
|
@ -0,0 +1,4 @@
|
||||||
|
Optimized the mean, variance, and stdev functions in the statistics module.
|
||||||
|
If the input is an iterator, it is consumed in a single pass rather than
|
||||||
|
eating memory by conversion to a list. The single pass algorithm is about
|
||||||
|
twice as fast as the previous two pass code.
|
Loading…
Reference in New Issue