mirror of https://github.com/python/cpython
Improve speed and accuracy for correlation() (GH-26135)
This commit is contained in:
parent
80b089179f
commit
fdfea4ab16
|
@ -107,9 +107,12 @@ A single exception is defined: StatisticsError is a subclass of ValueError.
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'NormalDist',
|
'NormalDist',
|
||||||
'StatisticsError',
|
'StatisticsError',
|
||||||
|
'correlation',
|
||||||
|
'covariance',
|
||||||
'fmean',
|
'fmean',
|
||||||
'geometric_mean',
|
'geometric_mean',
|
||||||
'harmonic_mean',
|
'harmonic_mean',
|
||||||
|
'linear_regression',
|
||||||
'mean',
|
'mean',
|
||||||
'median',
|
'median',
|
||||||
'median_grouped',
|
'median_grouped',
|
||||||
|
@ -122,9 +125,6 @@ __all__ = [
|
||||||
'quantiles',
|
'quantiles',
|
||||||
'stdev',
|
'stdev',
|
||||||
'variance',
|
'variance',
|
||||||
'correlation',
|
|
||||||
'covariance',
|
|
||||||
'linear_regression',
|
|
||||||
]
|
]
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
@ -882,10 +882,10 @@ def covariance(x, y, /):
|
||||||
raise StatisticsError('covariance requires that both inputs have same number of data points')
|
raise StatisticsError('covariance requires that both inputs have same number of data points')
|
||||||
if n < 2:
|
if n < 2:
|
||||||
raise StatisticsError('covariance requires at least two data points')
|
raise StatisticsError('covariance requires at least two data points')
|
||||||
xbar = fmean(x)
|
xbar = fsum(x) / n
|
||||||
ybar = fmean(y)
|
ybar = fsum(y) / n
|
||||||
total = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
|
sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
|
||||||
return total / (n - 1)
|
return sxy / (n - 1)
|
||||||
|
|
||||||
|
|
||||||
def correlation(x, y, /):
|
def correlation(x, y, /):
|
||||||
|
@ -910,11 +910,13 @@ def correlation(x, y, /):
|
||||||
raise StatisticsError('correlation requires that both inputs have same number of data points')
|
raise StatisticsError('correlation requires that both inputs have same number of data points')
|
||||||
if n < 2:
|
if n < 2:
|
||||||
raise StatisticsError('correlation requires at least two data points')
|
raise StatisticsError('correlation requires at least two data points')
|
||||||
cov = covariance(x, y)
|
xbar = fsum(x) / n
|
||||||
stdx = stdev(x)
|
ybar = fsum(y) / n
|
||||||
stdy = stdev(y)
|
sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
|
||||||
|
s2x = fsum((xi - xbar) ** 2.0 for xi in x)
|
||||||
|
s2y = fsum((yi - ybar) ** 2.0 for yi in y)
|
||||||
try:
|
try:
|
||||||
return cov / (stdx * stdy)
|
return sxy / sqrt(s2x * s2y)
|
||||||
except ZeroDivisionError:
|
except ZeroDivisionError:
|
||||||
raise StatisticsError('at least one of the inputs is constant')
|
raise StatisticsError('at least one of the inputs is constant')
|
||||||
|
|
||||||
|
@ -958,7 +960,7 @@ def linear_regression(regressor, dependent_variable, /):
|
||||||
sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
|
sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
|
||||||
s2x = fsum((xi - xbar) ** 2.0 for xi in x)
|
s2x = fsum((xi - xbar) ** 2.0 for xi in x)
|
||||||
try:
|
try:
|
||||||
slope = sxy / s2x
|
slope = sxy / s2x # equivalent to: covariance(x, y) / variance(x)
|
||||||
except ZeroDivisionError:
|
except ZeroDivisionError:
|
||||||
raise StatisticsError('regressor is constant')
|
raise StatisticsError('regressor is constant')
|
||||||
intercept = ybar - slope * xbar
|
intercept = ybar - slope * xbar
|
||||||
|
|
Loading…
Reference in New Issue