Skip to content

Commit ee588fb

Browse files
authored
Merge pull request #459 from jturner314/improve-var-axis
Improve docs, add tests, and use mul_add in var_axis
2 parents 6d9a44d + 219485d commit ee588fb

File tree

2 files changed

+70
-7
lines changed

2 files changed

+70
-7
lines changed

src/numeric/impl_numeric.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -140,18 +140,19 @@ impl<A, S, D> ArrayBase<S, D>
140140
/// n i=1
141141
/// ```
142142
///
143-
/// **Panics** if `ddof` is greater equal than the length of `axis`.
144-
/// **Panics** if `axis` is out of bounds or if length of `axis` is zero.
143+
/// **Panics** if `ddof` is greater than or equal to the length of the
144+
/// axis, if `axis` is out of bounds, or if the length of the axis is zero.
145145
///
146146
/// # Example
147147
///
148148
/// ```
149149
/// use ndarray::{aview1, arr2, Axis};
150150
///
151151
/// let a = arr2(&[[1., 2.],
152-
/// [3., 4.]]);
153-
/// let var = a.var_axis(Axis(0), 0.);
154-
/// assert_eq!(var, aview1(&[1., 1.]));
152+
/// [3., 4.],
153+
/// [5., 6.]]);
154+
/// let var = a.var_axis(Axis(0), 1.);
155+
/// assert_eq!(var, aview1(&[4., 4.]));
155156
/// ```
156157
pub fn var_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
157158
where
@@ -166,11 +167,11 @@ impl<A, S, D> ArrayBase<S, D>
166167
azip!(mut mean, mut sum_sq, x (subview) in {
167168
let delta = x - *mean;
168169
*mean = *mean + delta / count;
169-
*sum_sq = *sum_sq + delta * (x - *mean);
170+
*sum_sq = (x - *mean).mul_add(delta, *sum_sq);
170171
});
171172
}
172173
if ddof >= count {
173-
panic!("Ddof needs to be strictly smaller than the length \
174+
panic!("`ddof` needs to be strictly smaller than the length \
174175
of the axis you are computing the variance for!")
175176
} else {
176177
let dof = count - ddof;

tests/array.rs

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,68 @@ fn sum_mean()
689689
assert_eq!(a.scalar_sum(), 10.);
690690
}
691691

692+
#[test]
693+
fn var_axis() {
694+
let a = array![
695+
[
696+
[-9.76, -0.38, 1.59, 6.23],
697+
[-8.57, -9.27, 5.76, 6.01],
698+
[-9.54, 5.09, 3.21, 6.56],
699+
],
700+
[
701+
[ 8.23, -9.63, 3.76, -3.48],
702+
[-5.46, 5.86, -2.81, 1.35],
703+
[-1.08, 4.66, 8.34, -0.73],
704+
],
705+
];
706+
assert!(a.var_axis(Axis(0), 1.5).all_close(
707+
&aview2(&[
708+
[3.236401e+02, 8.556250e+01, 4.708900e+00, 9.428410e+01],
709+
[9.672100e+00, 2.289169e+02, 7.344490e+01, 2.171560e+01],
710+
[7.157160e+01, 1.849000e-01, 2.631690e+01, 5.314410e+01]
711+
]),
712+
1e-4,
713+
));
714+
assert!(a.var_axis(Axis(1), 1.7).all_close(
715+
&aview2(&[
716+
[0.61676923, 80.81092308, 6.79892308, 0.11789744],
717+
[75.19912821, 114.25235897, 48.32405128, 9.03020513],
718+
]),
719+
1e-8,
720+
));
721+
assert!(a.var_axis(Axis(2), 2.3).all_close(
722+
&aview2(&[
723+
[ 79.64552941, 129.09663235, 95.98929412],
724+
[109.64952941, 43.28758824, 36.27439706],
725+
]),
726+
1e-8,
727+
));
728+
729+
let b = array![[1.1, 2.3, 4.7]];
730+
assert!(b.var_axis(Axis(0), 0.).all_close(&aview1(&[0., 0., 0.]), 1e-12));
731+
assert!(b.var_axis(Axis(1), 0.).all_close(&aview1(&[2.24]), 1e-12));
732+
733+
let c = array![[], []];
734+
assert_eq!(c.var_axis(Axis(0), 0.), aview1(&[]));
735+
736+
let d = array![1.1, 2.7, 3.5, 4.9];
737+
assert!(d.var_axis(Axis(0), 0.).all_close(&aview0(&1.8875), 1e-12));
738+
}
739+
740+
#[test]
741+
#[should_panic]
742+
fn var_axis_bad_dof() {
743+
let a = array![1., 2., 3.];
744+
a.var_axis(Axis(0), 4.);
745+
}
746+
747+
#[test]
748+
#[should_panic]
749+
fn var_axis_empty_axis() {
750+
let a = array![[], []];
751+
a.var_axis(Axis(1), 0.);
752+
}
753+
692754
#[test]
693755
fn iter_size_hint()
694756
{

0 commit comments

Comments
 (0)