From 43aac29cbbb8a963a22c334b5b795d1e43417d6b Mon Sep 17 00:00:00 2001 From: Raymond Hettinger Date: Wed, 5 Jan 2022 07:39:10 -0800 Subject: [PATCH] bpo-46257: Convert statistics._ss() to a single pass algorithm (GH-30403) --- Lib/statistics.py | 100 ++++++++---------- .../2022-01-04-11-04-20.bpo-46257._o2ADe.rst | 4 + 2 files changed, 47 insertions(+), 57 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2022-01-04-11-04-20.bpo-46257._o2ADe.rst diff --git a/Lib/statistics.py b/Lib/statistics.py index c104571d390..eef2453bc73 100644 --- a/Lib/statistics.py +++ b/Lib/statistics.py @@ -138,7 +138,7 @@ from itertools import groupby, repeat from bisect import bisect_left, bisect_right from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum from operator import mul -from collections import Counter, namedtuple +from collections import Counter, namedtuple, defaultdict _SQRT2 = sqrt(2.0) @@ -202,6 +202,43 @@ def _sum(data): return (T, total, count) +def _ss(data, c=None): + """Return sum of square deviations of sequence data. + + If ``c`` is None, the mean is calculated in one pass, and the deviations + from the mean are calculated in a second pass. Otherwise, deviations are + calculated from ``c`` as given. Use the second case with care, as it can + lead to garbage results. + """ + if c is not None: + T, total, count = _sum((d := x - c) * d for x in data) + return (T, total, count) + count = 0 + sx_partials = defaultdict(int) + sxx_partials = defaultdict(int) + T = int + for typ, values in groupby(data, type): + T = _coerce(T, typ) # or raise TypeError + for n, d in map(_exact_ratio, values): + count += 1 + sx_partials[d] += n + sxx_partials[d] += n * n + if not count: + total = Fraction(0) + elif None in sx_partials: + # The sum will be a NAN or INF. We can ignore all the finite + # partials, and just look at this special one. + total = sx_partials[None] + assert not _isfinite(total) + else: + sx = sum(Fraction(n, d) for d, n in sx_partials.items()) + sxx = sum(Fraction(n, d*d) for d, n in sxx_partials.items()) + # This formula has poor numeric properties for floats, + # but with fractions it is exact. + total = (count * sxx - sx * sx) / count + return (T, total, count) + + def _isfinite(x): try: return x.is_finite() # Likely a Decimal. @@ -399,13 +436,9 @@ def mean(data): If ``data`` is empty, StatisticsError will be raised. """ - if iter(data) is data: - data = list(data) - n = len(data) + T, total, n = _sum(data) if n < 1: raise StatisticsError('mean requires at least one data point') - T, total, count = _sum(data) - assert count == n return _convert(total / n, T) @@ -776,41 +809,6 @@ def quantiles(data, *, n=4, method='exclusive'): # See http://mathworld.wolfram.com/Variance.html # http://mathworld.wolfram.com/SampleVariance.html -# http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance -# -# Under no circumstances use the so-called "computational formula for -# variance", as that is only suitable for hand calculations with a small -# amount of low-precision data. It has terrible numeric properties. -# -# See a comparison of three computational methods here: -# http://www.johndcook.com/blog/2008/09/26/comparing-three-methods-of-computing-standard-deviation/ - -def _ss(data, c=None): - """Return sum of square deviations of sequence data. - - If ``c`` is None, the mean is calculated in one pass, and the deviations - from the mean are calculated in a second pass. Otherwise, deviations are - calculated from ``c`` as given. Use the second case with care, as it can - lead to garbage results. - """ - if c is not None: - T, total, count = _sum((d := x - c) * d for x in data) - return (T, total) - T, total, count = _sum(data) - mean_n, mean_d = (total / count).as_integer_ratio() - partials = Counter() - for n, d in map(_exact_ratio, data): - diff_n = n * mean_d - d * mean_n - diff_d = d * mean_d - partials[diff_d * diff_d] += diff_n * diff_n - if None in partials: - # The sum will be a NAN or INF. We can ignore all the finite - # partials, and just look at this special one. - total = partials[None] - assert not _isfinite(total) - else: - total = sum(Fraction(n, d) for d, n in partials.items()) - return (T, total) def variance(data, xbar=None): @@ -851,12 +849,9 @@ def variance(data, xbar=None): Fraction(67, 108) """ - if iter(data) is data: - data = list(data) - n = len(data) + T, ss, n = _ss(data, xbar) if n < 2: raise StatisticsError('variance requires at least two data points') - T, ss = _ss(data, xbar) return _convert(ss / (n - 1), T) @@ -895,12 +890,9 @@ def pvariance(data, mu=None): Fraction(13, 72) """ - if iter(data) is data: - data = list(data) - n = len(data) + T, ss, n = _ss(data, mu) if n < 1: raise StatisticsError('pvariance requires at least one data point') - T, ss = _ss(data, mu) return _convert(ss / n, T) @@ -913,12 +905,9 @@ def stdev(data, xbar=None): 1.0810874155219827 """ - if iter(data) is data: - data = list(data) - n = len(data) + T, ss, n = _ss(data, xbar) if n < 2: raise StatisticsError('stdev requires at least two data points') - T, ss = _ss(data, xbar) mss = ss / (n - 1) if issubclass(T, Decimal): return _decimal_sqrt_of_frac(mss.numerator, mss.denominator) @@ -934,12 +923,9 @@ def pstdev(data, mu=None): 0.986893273527251 """ - if iter(data) is data: - data = list(data) - n = len(data) + T, ss, n = _ss(data, mu) if n < 1: raise StatisticsError('pstdev requires at least one data point') - T, ss = _ss(data, mu) mss = ss / n if issubclass(T, Decimal): return _decimal_sqrt_of_frac(mss.numerator, mss.denominator) diff --git a/Misc/NEWS.d/next/Library/2022-01-04-11-04-20.bpo-46257._o2ADe.rst b/Misc/NEWS.d/next/Library/2022-01-04-11-04-20.bpo-46257._o2ADe.rst new file mode 100644 index 00000000000..72ae56ec412 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2022-01-04-11-04-20.bpo-46257._o2ADe.rst @@ -0,0 +1,4 @@ +Optimized the mean, variance, and stdev functions in the statistics module. +If the input is an iterator, it is consumed in a single pass rather than +eating memory by conversion to a list. The single pass algorithm is about +twice as fast as the previous two pass code.