Skip to content

Commit 17a0b99

Browse files
Pearson correlation (#5)
* Signature of function has been provided.
* Added unimplemented sketch to trait implementation.
* Added new test module for pearson correlation tests.
* Added a first test for pearson correlation.
* Depending on master branch of ndarray.
* Basic implementation of pearson_correlation.
* Improved docstring for pearson_correlation.
* Changed test name - no need to repeat, given test module name.
* Check what happens with constant random variables.
* Removed double ndarray patch.
* Changed ddof in pearson to avoid panic.
* Fixed docs for Pearson Correlation - zero division panic.
* Test Pearson correlation on random input.
* Remove println statement in test.
* Debugging.
* Fix typo
* Adding assertion in zero_observations case for covariance.
* Simplified one test.
* Fix typo in test
* Fixing documentation typo
* Adding docs for the whole trait
* Added doc-test to pearson correlation method.
* Added Example header before doc-test
1 parent 5f12a2e commit 17a0b99

File tree

1 file changed

+162
-14
lines changed

1 file changed

+162
-14
lines changed

src/correlation.rs

Lines changed: 162 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ use ndarray::prelude::*;
22
use ndarray::Data;
33
use num_traits::{Float, FromPrimitive};
44

5+
/// Extension trait for ArrayBase providing functions
6+
/// to compute different correlation measures.
57
pub trait CorrelationExt<A, S>
68
where
79
S: Data<Elem = A>,
@@ -11,13 +13,13 @@ where
1113
///
1214
/// Let `(r, o)` be the shape of `M`:
1315
/// - `r` is the number of random variables;
14-
/// - `o` is the number of observations we have collected
16+
/// - `o` is the number of observations we have collected
1517
/// for each random variable.
16-
///
17-
/// Every column in `M` is an experiment: a single observation for each
18+
///
19+
/// Every column in `M` is an experiment: a single observation for each
1820
/// random variable.
1921
/// Each row in `M` contains all the observations for a certain random variable.
20-
///
22+
///
2123
/// The parameter `ddof` specifies the "delta degrees of freedom". For
2224
/// example, to calculate the population covariance, use `ddof = 0`, or to
2325
/// calculate the sample covariance (unbiased estimate), use `ddof = 1`.
@@ -37,7 +39,7 @@ where
3739
/// x̅ = ― ∑ xᵢ
3840
/// n i=1
3941
/// ```
40-
/// and similarly for ̅y.
42+
/// and similarly for y̅.
4143
///
4244
/// **Panics** if `ddof` is greater than or equal to the number of
4345
/// observations, if the number of observations is zero and division by
@@ -56,11 +58,65 @@ where
5658
/// [2., 4., 6.]]);
5759
/// let covariance = a.cov(1.);
5860
/// assert_eq!(
59-
/// covariance,
61+
/// covariance,
6062
/// aview2(&[[4., 4.], [4., 4.]])
6163
/// );
6264
/// ```
63-
fn cov(&self, ddof: A) -> Array2<A>
65+
fn cov(&self, ddof: A) -> Array2<A>
66+
where
67+
A: Float + FromPrimitive;
68+
69+
/// Return the [Pearson correlation coefficients](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient)
70+
/// for a 2-dimensional array of observations `M`.
71+
///
72+
/// Let `(r, o)` be the shape of `M`:
73+
/// - `r` is the number of random variables;
74+
/// - `o` is the number of observations we have collected
75+
/// for each random variable.
76+
///
77+
/// Every column in `M` is an experiment: a single observation for each
78+
/// random variable.
79+
/// Each row in `M` contains all the observations for a certain random variable.
80+
///
81+
/// The Pearson correlation coefficient of two random variables is defined as:
82+
///
83+
/// ```text
84+
/// cov(X, Y)
85+
/// rho(X, Y) = ――――――――――――
86+
/// std(X)std(Y)
87+
/// ```
88+
///
89+
/// Let `R` be the matrix returned by this function. Then
90+
/// ```text
91+
/// R_ij = rho(X_i, X_j)
92+
/// ```
93+
///
94+
/// **Panics** if `M` is empty, if the type cast of `n_observations`
95+
/// from `usize` to `A` fails or if the standard deviation of one of the random
96+
/// variables is zero and division by zero panics for type `A`.
97+
///
98+
/// # Example
99+
///
100+
/// ```
101+
/// extern crate ndarray;
102+
/// extern crate ndarray_stats;
103+
/// use ndarray::arr2;
104+
/// use ndarray_stats::CorrelationExt;
105+
///
106+
/// let a = arr2(&[[1., 3., 5.],
107+
/// [2., 4., 6.]]);
108+
/// let corr = a.pearson_correlation();
109+
/// assert!(
110+
/// corr.all_close(
111+
/// &arr2(&[
112+
/// [1., 1.],
113+
/// [1., 1.],
114+
/// ]),
115+
/// 1e-7
116+
/// )
117+
/// );
118+
/// ```
119+
fn pearson_correlation(&self) -> Array2<A>
64120
where
65121
A: Float + FromPrimitive;
66122
}
@@ -75,7 +131,7 @@ where
75131
{
76132
let observation_axis = Axis(1);
77133
let n_observations = A::from_usize(self.len_of(observation_axis)).unwrap();
78-
let dof =
134+
let dof =
79135
if ddof >= n_observations {
80136
panic!("`ddof` needs to be strictly smaller than the \
81137
number of observations provided for each \
@@ -88,16 +144,33 @@ where
88144
let covariance = denoised.dot(&denoised.t());
89145
covariance.mapv_into(|x| x / dof)
90146
}
147+
148+
fn pearson_correlation(&self) -> Array2<A>
149+
where
150+
A: Float + FromPrimitive,
151+
{
152+
let observation_axis = Axis(1);
153+
// The ddof value doesn't matter, as long as we use the same one
154+
// for computing covariance and standard deviation.
155+
// We choose -1 to avoid panicking when we only have one
156+
// observation per random variable (or no observations at all)
157+
let ddof = -A::one();
158+
let cov = self.cov(ddof);
159+
let std = self.std_axis(observation_axis, ddof).insert_axis(observation_axis);
160+
let std_matrix = std.dot(&std.t());
161+
// element-wise division
162+
cov / std_matrix
163+
}
91164
}
92165

93166
#[cfg(test)]
94-
mod tests {
167+
mod cov_tests {
95168
use super::*;
96169
use rand;
97170
use rand::distributions::Range;
98171
use ndarray_rand::RandomExt;
99172

100-
quickcheck! {
173+
quickcheck! {
101174
fn constant_random_variables_have_zero_covariance_matrix(value: f64) -> bool {
102175
let n_random_variables = 3;
103176
let n_observations = 4;
@@ -112,21 +185,21 @@ mod tests {
112185
let n_random_variables = 3;
113186
let n_observations = 4;
114187
let a = Array::random(
115-
(n_random_variables, n_observations),
188+
(n_random_variables, n_observations),
116189
Range::new(-bound.abs(), bound.abs())
117190
);
118191
let covariance = a.cov(1.);
119192
covariance.all_close(&covariance.t(), 1e-8)
120193
}
121194
}
122-
195+
123196
#[test]
124197
#[should_panic]
125198
fn test_invalid_ddof() {
126199
let n_random_variables = 3;
127200
let n_observations = 4;
128201
let a = Array::random(
129-
(n_random_variables, n_observations),
202+
(n_random_variables, n_observations),
130203
Range::new(0., 10.)
131204
);
132205
let invalid_ddof = (n_observations as f64) + rand::random::<f64>().abs();
@@ -200,4 +273,79 @@ mod tests {
200273
)
201274
);
202275
}
203-
}
276+
}
277+
278+
#[cfg(test)]
279+
mod pearson_correlation_tests {
280+
use super::*;
281+
use rand::distributions::Range;
282+
use ndarray_rand::RandomExt;
283+
284+
quickcheck! {
285+
fn output_matrix_is_symmetric(bound: f64) -> bool {
286+
let n_random_variables = 3;
287+
let n_observations = 4;
288+
let a = Array::random(
289+
(n_random_variables, n_observations),
290+
Range::new(-bound.abs(), bound.abs())
291+
);
292+
let pearson_correlation = a.pearson_correlation();
293+
pearson_correlation.all_close(&pearson_correlation.t(), 1e-8)
294+
}
295+
296+
fn constant_random_variables_have_nan_correlation(value: f64) -> bool {
297+
let n_random_variables = 3;
298+
let n_observations = 4;
299+
let a = Array::from_elem((n_random_variables, n_observations), value);
300+
let pearson_correlation = a.pearson_correlation();
301+
pearson_correlation.iter().map(|x| x.is_nan()).fold(true, |acc, flag| acc & flag)
302+
}
303+
}
304+
305+
#[test]
306+
fn test_zero_variables() {
307+
let a = Array2::<f32>::zeros((0, 2));
308+
let pearson_correlation = a.pearson_correlation();
309+
assert_eq!(pearson_correlation.shape(), &[0, 0]);
310+
}
311+
312+
#[test]
313+
fn test_zero_observations() {
314+
let a = Array2::<f32>::zeros((2, 0));
315+
let pearson = a.pearson_correlation();
316+
pearson.mapv(|x| x.is_nan());
317+
}
318+
319+
#[test]
320+
fn test_zero_variables_zero_observations() {
321+
let a = Array2::<f32>::zeros((0, 0));
322+
let pearson = a.pearson_correlation();
323+
assert_eq!(pearson.shape(), &[0, 0]);
324+
}
325+
326+
#[test]
327+
fn test_for_random_array() {
328+
let a = array![
329+
[0.16351516, 0.56863268, 0.16924196, 0.72579120],
330+
[0.44342453, 0.19834387, 0.25411802, 0.62462382],
331+
[0.97162731, 0.29958849, 0.17338142, 0.80198342],
332+
[0.91727132, 0.79817799, 0.62237124, 0.38970998],
333+
[0.26979716, 0.20887228, 0.95454999, 0.96290785]
334+
];
335+
let numpy_corrcoeff = array![
336+
[ 1. , 0.38089376, 0.08122504, -0.59931623, 0.1365648 ],
337+
[ 0.38089376, 1. , 0.80918429, -0.52615195, 0.38954398],
338+
[ 0.08122504, 0.80918429, 1. , 0.07134906, -0.17324776],
339+
[-0.59931623, -0.52615195, 0.07134906, 1. , -0.8743213 ],
340+
[ 0.1365648 , 0.38954398, -0.17324776, -0.8743213 , 1. ]
341+
];
342+
assert_eq!(a.ndim(), 2);
343+
assert!(
344+
a.pearson_correlation().all_close(
345+
&numpy_corrcoeff,
346+
1e-7
347+
)
348+
);
349+
}
350+
351+
}

0 commit comments

Comments
 (0)