Improve speed and accuracy for correlation() (GH-26135)

This commit is contained in:
Raymond Hettinger 2021-05-15 11:00:51 -07:00 committed by GitHub
parent 80b089179f
commit fdfea4ab16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 14 additions and 12 deletions

View File

@@ -107,9 +107,12 @@ A single exception is defined: StatisticsError is a subclass of ValueError.
__all__ = [ __all__ = [
'NormalDist', 'NormalDist',
'StatisticsError', 'StatisticsError',
'correlation',
'covariance',
'fmean', 'fmean',
'geometric_mean', 'geometric_mean',
'harmonic_mean', 'harmonic_mean',
'linear_regression',
'mean', 'mean',
'median', 'median',
'median_grouped', 'median_grouped',
@@ -122,9 +125,6 @@ __all__ = [
'quantiles', 'quantiles',
'stdev', 'stdev',
'variance', 'variance',
'correlation',
'covariance',
'linear_regression',
] ]
import math import math
@@ -882,10 +882,10 @@ def covariance(x, y, /):
raise StatisticsError('covariance requires that both inputs have same number of data points') raise StatisticsError('covariance requires that both inputs have same number of data points')
if n < 2: if n < 2:
raise StatisticsError('covariance requires at least two data points') raise StatisticsError('covariance requires at least two data points')
xbar = fmean(x) xbar = fsum(x) / n
ybar = fmean(y) ybar = fsum(y) / n
total = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y)) sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
return total / (n - 1) return sxy / (n - 1)
def correlation(x, y, /): def correlation(x, y, /):
@@ -910,11 +910,13 @@ def correlation(x, y, /):
raise StatisticsError('correlation requires that both inputs have same number of data points') raise StatisticsError('correlation requires that both inputs have same number of data points')
if n < 2: if n < 2:
raise StatisticsError('correlation requires at least two data points') raise StatisticsError('correlation requires at least two data points')
cov = covariance(x, y) xbar = fsum(x) / n
stdx = stdev(x) ybar = fsum(y) / n
stdy = stdev(y) sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
s2x = fsum((xi - xbar) ** 2.0 for xi in x)
s2y = fsum((yi - ybar) ** 2.0 for yi in y)
try: try:
return cov / (stdx * stdy) return sxy / sqrt(s2x * s2y)
except ZeroDivisionError: except ZeroDivisionError:
raise StatisticsError('at least one of the inputs is constant') raise StatisticsError('at least one of the inputs is constant')
@@ -958,7 +960,7 @@ def linear_regression(regressor, dependent_variable, /):
sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y)) sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
s2x = fsum((xi - xbar) ** 2.0 for xi in x) s2x = fsum((xi - xbar) ** 2.0 for xi in x)
try: try:
slope = sxy / s2x slope = sxy / s2x # equivalent to: covariance(x, y) / variance(x)
except ZeroDivisionError: except ZeroDivisionError:
raise StatisticsError('regressor is constant') raise StatisticsError('regressor is constant')
intercept = ybar - slope * xbar intercept = ybar - slope * xbar