@@ -2,6 +2,8 @@ use ndarray::prelude::*;
2
2
use ndarray:: Data ;
3
3
use num_traits:: { Float , FromPrimitive } ;
4
4
5
+ /// Extension trait for ArrayBase providing functions
6
+ /// to compute different correlation measures.
5
7
pub trait CorrelationExt < A , S >
6
8
where
7
9
S : Data < Elem = A > ,
@@ -11,13 +13,13 @@ where
11
13
///
12
14
/// Let `(r, o)` be the shape of `M`:
13
15
/// - `r` is the number of random variables;
14
- /// - `o` is the number of observations we have collected
16
+ /// - `o` is the number of observations we have collected
15
17
/// for each random variable.
16
- ///
17
- /// Every column in `M` is an experiment: a single observation for each
18
+ ///
19
+ /// Every column in `M` is an experiment: a single observation for each
18
20
/// random variable.
19
21
/// Each row in `M` contains all the observations for a certain random variable.
20
- ///
22
+ ///
21
23
/// The parameter `ddof` specifies the "delta degrees of freedom". For
22
24
/// example, to calculate the population covariance, use `ddof = 0`, or to
23
25
/// calculate the sample covariance (unbiased estimate), use `ddof = 1`.
37
39
/// x̅ = ― ∑ xᵢ
38
40
/// n i=1
39
41
/// ```
40
- /// and similarly for ̅y.
42
+ /// and similarly for ̅y.
41
43
///
42
44
/// **Panics** if `ddof` is greater than or equal to the number of
43
45
/// observations, if the number of observations is zero and division by
@@ -56,11 +58,65 @@ where
56
58
/// [2., 4., 6.]]);
57
59
/// let covariance = a.cov(1.);
58
60
/// assert_eq!(
59
- /// covariance,
61
+ /// covariance,
60
62
/// aview2(&[[4., 4.], [4., 4.]])
61
63
/// );
62
64
/// ```
63
- fn cov ( & self , ddof : A ) -> Array2 < A >
65
+ fn cov ( & self , ddof : A ) -> Array2 < A >
66
+ where
67
+ A : Float + FromPrimitive ;
68
+
69
+ /// Return the [Pearson correlation coefficients](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient)
70
+ /// for a 2-dimensional array of observations `M`.
71
+ ///
72
+ /// Let `(r, o)` be the shape of `M`:
73
+ /// - `r` is the number of random variables;
74
+ /// - `o` is the number of observations we have collected
75
+ /// for each random variable.
76
+ ///
77
+ /// Every column in `M` is an experiment: a single observation for each
78
+ /// random variable.
79
+ /// Each row in `M` contains all the observations for a certain random variable.
80
+ ///
81
+ /// The Pearson correlation coefficient of two random variables is defined as:
82
+ ///
83
+ /// ```text
84
+ /// cov(X, Y)
85
+ /// rho(X, Y) = ――――――――――――
86
+ /// std(X)std(Y)
87
+ /// ```
88
+ ///
89
+ /// Let `R` be the matrix returned by this function. Then
90
+ /// ```text
91
+ /// R_ij = rho(X_i, X_j)
92
+ /// ```
93
+ ///
94
+ /// **Panics** if `M` is empty, if the type cast of `n_observations`
95
+ /// from `usize` to `A` fails or if the standard deviation of one of the random
96
+ ///
97
+ /// # Example
98
+ ///
99
+ /// variables is zero and division by zero panics for type A.
100
+ /// ```
101
+ /// extern crate ndarray;
102
+ /// extern crate ndarray_stats;
103
+ /// use ndarray::arr2;
104
+ /// use ndarray_stats::CorrelationExt;
105
+ ///
106
+ /// let a = arr2(&[[1., 3., 5.],
107
+ /// [2., 4., 6.]]);
108
+ /// let corr = a.pearson_correlation();
109
+ /// assert!(
110
+ /// corr.all_close(
111
+ /// &arr2(&[
112
+ /// [1., 1.],
113
+ /// [1., 1.],
114
+ /// ]),
115
+ /// 1e-7
116
+ /// )
117
+ /// );
118
+ /// ```
119
+ fn pearson_correlation ( & self ) -> Array2 < A >
64
120
where
65
121
A : Float + FromPrimitive ;
66
122
}
75
131
{
76
132
let observation_axis = Axis ( 1 ) ;
77
133
let n_observations = A :: from_usize ( self . len_of ( observation_axis) ) . unwrap ( ) ;
78
- let dof =
134
+ let dof =
79
135
if ddof >= n_observations {
80
136
panic ! ( "`ddof` needs to be strictly smaller than the \
81
137
number of observations provided for each \
@@ -88,16 +144,33 @@ where
88
144
let covariance = denoised. dot ( & denoised. t ( ) ) ;
89
145
covariance. mapv_into ( |x| x / dof)
90
146
}
147
+
148
+ fn pearson_correlation ( & self ) -> Array2 < A >
149
+ where
150
+ A : Float + FromPrimitive ,
151
+ {
152
+ let observation_axis = Axis ( 1 ) ;
153
+ // The ddof value doesn't matter, as long as we use the same one
154
+ // for computing covariance and standard deviation
155
+ // We choose -1 to avoid panicking when we only have one
156
+ // observation per random variable (or no observations at all)
157
+ let ddof = -A :: one ( ) ;
158
+ let cov = self . cov ( ddof) ;
159
+ let std = self . std_axis ( observation_axis, ddof) . insert_axis ( observation_axis) ;
160
+ let std_matrix = std. dot ( & std. t ( ) ) ;
161
+ // element-wise division
162
+ cov / std_matrix
163
+ }
91
164
}
92
165
93
166
#[ cfg( test) ]
94
- mod tests {
167
+ mod cov_tests {
95
168
use super :: * ;
96
169
use rand;
97
170
use rand:: distributions:: Range ;
98
171
use ndarray_rand:: RandomExt ;
99
172
100
- quickcheck ! {
173
+ quickcheck ! {
101
174
fn constant_random_variables_have_zero_covariance_matrix( value: f64 ) -> bool {
102
175
let n_random_variables = 3 ;
103
176
let n_observations = 4 ;
@@ -112,21 +185,21 @@ mod tests {
112
185
let n_random_variables = 3 ;
113
186
let n_observations = 4 ;
114
187
let a = Array :: random(
115
- ( n_random_variables, n_observations) ,
188
+ ( n_random_variables, n_observations) ,
116
189
Range :: new( -bound. abs( ) , bound. abs( ) )
117
190
) ;
118
191
let covariance = a. cov( 1. ) ;
119
192
covariance. all_close( & covariance. t( ) , 1e-8 )
120
193
}
121
194
}
122
-
195
+
123
196
#[ test]
124
197
#[ should_panic]
125
198
fn test_invalid_ddof ( ) {
126
199
let n_random_variables = 3 ;
127
200
let n_observations = 4 ;
128
201
let a = Array :: random (
129
- ( n_random_variables, n_observations) ,
202
+ ( n_random_variables, n_observations) ,
130
203
Range :: new ( 0. , 10. )
131
204
) ;
132
205
let invalid_ddof = ( n_observations as f64 ) + rand:: random :: < f64 > ( ) . abs ( ) ;
@@ -200,4 +273,79 @@ mod tests {
200
273
)
201
274
) ;
202
275
}
203
- }
276
+ }
277
+
278
#[cfg(test)]
mod pearson_correlation_tests {
    use super::*;
    use rand::distributions::Range;
    use ndarray_rand::RandomExt;

    quickcheck! {
        // rho(X_i, X_j) == rho(X_j, X_i), so the result must match its transpose.
        fn output_matrix_is_symmetric(bound: f64) -> bool {
            let n_random_variables = 3;
            let n_observations = 4;
            let a = Array::random(
                (n_random_variables, n_observations),
                Range::new(-bound.abs(), bound.abs())
            );
            let pearson_correlation = a.pearson_correlation();
            pearson_correlation.all_close(&pearson_correlation.t(), 1e-8)
        }

        // Constant variables have zero standard deviation, so every
        // coefficient is a 0/0 division and is expected to be NaN.
        fn constant_random_variables_have_nan_correlation(value: f64) -> bool {
            let n_random_variables = 3;
            let n_observations = 4;
            let a = Array::from_elem((n_random_variables, n_observations), value);
            let pearson_correlation = a.pearson_correlation();
            pearson_correlation.iter().all(|x| x.is_nan())
        }
    }

    // No random variables: the correlation matrix is empty.
    #[test]
    fn test_zero_variables() {
        let a = Array2::<f32>::zeros((0, 2));
        let pearson_correlation = a.pearson_correlation();
        assert_eq!(pearson_correlation.shape(), &[0, 0]);
    }

    // Two random variables but no observations at all.
    #[test]
    fn test_zero_observations() {
        let a = Array2::<f32>::zeros((2, 0));
        let pearson = a.pearson_correlation();
        // The NaN mask used to be computed and then dropped, so this test
        // could never fail. Assert on it instead.
        // NOTE(review): the all-NaN expectation assumes covariance and
        // standard deviation are both zero when there are no observations
        // (0/0 division) - confirm against the ndarray version in use.
        assert_eq!(pearson.shape(), &[2, 2]);
        assert!(pearson.iter().all(|x| x.is_nan()));
    }

    // Completely empty input: the correlation matrix is empty as well.
    #[test]
    fn test_zero_variables_zero_observations() {
        let a = Array2::<f32>::zeros((0, 0));
        let pearson = a.pearson_correlation();
        assert_eq!(pearson.shape(), &[0, 0]);
    }

    // Regression check against reference coefficients; the expected matrix
    // is named after NumPy's `corrcoef`, which presumably produced it.
    #[test]
    fn test_for_random_array() {
        let a = array![
            [0.16351516, 0.56863268, 0.16924196, 0.72579120],
            [0.44342453, 0.19834387, 0.25411802, 0.62462382],
            [0.97162731, 0.29958849, 0.17338142, 0.80198342],
            [0.91727132, 0.79817799, 0.62237124, 0.38970998],
            [0.26979716, 0.20887228, 0.95454999, 0.96290785]
        ];
        let numpy_corrcoeff = array![
            [1., 0.38089376, 0.08122504, -0.59931623, 0.1365648],
            [0.38089376, 1., 0.80918429, -0.52615195, 0.38954398],
            [0.08122504, 0.80918429, 1., 0.07134906, -0.17324776],
            [-0.59931623, -0.52615195, 0.07134906, 1., -0.8743213],
            [0.1365648, 0.38954398, -0.17324776, -0.8743213, 1.]
        ];
        assert_eq!(a.ndim(), 2);
        assert!(
            a.pearson_correlation().all_close(
                &numpy_corrcoeff,
                1e-7
            )
        );
    }
}
0 commit comments