Improve speed and accuracy for correlation() (GH-26135)

This commit is contained in:
Raymond Hettinger 2021-05-15 11:00:51 -07:00 committed by GitHub
parent 80b089179f
commit fdfea4ab16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 14 additions and 12 deletions

View File

@ -107,9 +107,12 @@ A single exception is defined: StatisticsError is a subclass of ValueError.
__all__ = [
'NormalDist',
'StatisticsError',
'correlation',
'covariance',
'fmean',
'geometric_mean',
'harmonic_mean',
'linear_regression',
'mean',
'median',
'median_grouped',
@ -122,9 +125,6 @@ __all__ = [
'quantiles',
'stdev',
'variance',
'correlation',
'covariance',
'linear_regression',
]
import math
@ -882,10 +882,10 @@ def covariance(x, y, /):
raise StatisticsError('covariance requires that both inputs have same number of data points')
if n < 2:
raise StatisticsError('covariance requires at least two data points')
xbar = fmean(x)
ybar = fmean(y)
total = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
return total / (n - 1)
xbar = fsum(x) / n
ybar = fsum(y) / n
sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
return sxy / (n - 1)
def correlation(x, y, /):
@ -910,11 +910,13 @@ def correlation(x, y, /):
raise StatisticsError('correlation requires that both inputs have same number of data points')
if n < 2:
raise StatisticsError('correlation requires at least two data points')
cov = covariance(x, y)
stdx = stdev(x)
stdy = stdev(y)
xbar = fsum(x) / n
ybar = fsum(y) / n
sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
s2x = fsum((xi - xbar) ** 2.0 for xi in x)
s2y = fsum((yi - ybar) ** 2.0 for yi in y)
try:
return cov / (stdx * stdy)
return sxy / sqrt(s2x * s2y)
except ZeroDivisionError:
raise StatisticsError('at least one of the inputs is constant')
@ -958,7 +960,7 @@ def linear_regression(regressor, dependent_variable, /):
sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
s2x = fsum((xi - xbar) ** 2.0 for xi in x)
try:
slope = sxy / s2x
slope = sxy / s2x # equivalent to: covariance(x, y) / variance(x)
except ZeroDivisionError:
raise StatisticsError('regressor is constant')
intercept = ybar - slope * xbar