Skip to content

Commit 5442cfa

Browse files
Improve speed and accuracy for correlation() (GH-26135) (GH-26151)
1 parent 0d441d2 commit 5442cfa

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

Lib/statistics.py

Lines changed: 14 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -107,9 +107,12 @@
107107
__all__ = [
108108
'NormalDist',
109109
'StatisticsError',
110+
'correlation',
111+
'covariance',
110112
'fmean',
111113
'geometric_mean',
112114
'harmonic_mean',
115+
'linear_regression',
113116
'mean',
114117
'median',
115118
'median_grouped',
@@ -122,9 +125,6 @@
122125
'quantiles',
123126
'stdev',
124127
'variance',
125-
'correlation',
126-
'covariance',
127-
'linear_regression',
128128
]
129129

130130
import math
@@ -882,10 +882,10 @@ def covariance(x, y, /):
882882
raise StatisticsError('covariance requires that both inputs have same number of data points')
883883
if n < 2:
884884
raise StatisticsError('covariance requires at least two data points')
885-
xbar = fmean(x)
886-
ybar = fmean(y)
887-
total = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
888-
return total / (n - 1)
885+
xbar = fsum(x) / n
886+
ybar = fsum(y) / n
887+
sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
888+
return sxy / (n - 1)
889889

890890

891891
def correlation(x, y, /):
@@ -910,11 +910,13 @@ def correlation(x, y, /):
910910
raise StatisticsError('correlation requires that both inputs have same number of data points')
911911
if n < 2:
912912
raise StatisticsError('correlation requires at least two data points')
913-
cov = covariance(x, y)
914-
stdx = stdev(x)
915-
stdy = stdev(y)
913+
xbar = fsum(x) / n
914+
ybar = fsum(y) / n
915+
sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
916+
s2x = fsum((xi - xbar) ** 2.0 for xi in x)
917+
s2y = fsum((yi - ybar) ** 2.0 for yi in y)
916918
try:
917-
return cov / (stdx * stdy)
919+
return sxy / sqrt(s2x * s2y)
918920
except ZeroDivisionError:
919921
raise StatisticsError('at least one of the inputs is constant')
920922

@@ -957,7 +959,7 @@ def linear_regression(x, y, /):
957959
sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
958960
s2x = fsum((xi - xbar) ** 2.0 for xi in x)
959961
try:
960-
slope = sxy / s2x
962+
slope = sxy / s2x # equivalent to: covariance(x, y) / variance(x)
961963
except ZeroDivisionError:
962964
raise StatisticsError('x is constant')
963965
intercept = ybar - slope * xbar

0 commit comments

Comments
 (0)