Skip to content

Commit 17a628e

Browse files
akern40adamreichold
authored andcommitted
Implements and tests product_axis.
1 parent bde682a commit 17a628e

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

src/numeric/impl_numeric.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#[cfg(feature = "std")]
1010
use num_traits::Float;
11+
use num_traits::One;
1112
use num_traits::{FromPrimitive, Zero};
1213
use std::ops::{Add, Div, Mul};
1314

@@ -253,6 +254,43 @@ where
253254
}
254255
}
255256

257+
/// Return product along `axis`.
258+
///
259+
/// The product of an empty array is 1.
260+
///
261+
/// ```
262+
/// use ndarray::{aview0, aview1, arr2, Axis};
263+
///
264+
/// let a = arr2(&[[1., 2., 3.],
265+
/// [4., 5., 6.]]);
266+
///
267+
/// assert!(
268+
/// a.product_axis(Axis(0)) == aview1(&[4., 10., 18.]) &&
269+
/// a.product_axis(Axis(1)) == aview1(&[6., 120.]) &&
270+
///
271+
/// a.product_axis(Axis(0)).product_axis(Axis(0)) == aview0(&720.)
272+
/// );
273+
/// ```
274+
///
275+
/// **Panics** if `axis` is out of bounds.
276+
#[track_caller]
277+
pub fn product_axis(&self, axis: Axis) -> Array<A, D::Smaller>
278+
where
279+
A: Clone + One + Mul<Output = A>,
280+
D: RemoveAxis,
281+
{
282+
let min_stride_axis = self.dim.min_stride_axis(&self.strides);
283+
if axis == min_stride_axis {
284+
crate::Zip::from(self.lanes(axis)).map_collect(|lane| lane.product())
285+
} else {
286+
let mut res = Array::ones(self.raw_dim().remove_axis(axis));
287+
for subview in self.axis_iter(axis) {
288+
res = res * &subview;
289+
}
290+
res
291+
}
292+
}
293+
256294
/// Return mean along `axis`.
257295
///
258296
/// Return `None` if the length of the axis is zero.

tests/numeric.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,27 +39,36 @@ fn test_mean_with_array_of_floats()
3939
}
4040

4141
#[test]
42-
fn sum_mean()
42+
fn sum_mean_prod()
4343
{
4444
let a: Array2<f64> = arr2(&[[1., 2.], [3., 4.]]);
4545
assert_eq!(a.sum_axis(Axis(0)), arr1(&[4., 6.]));
4646
assert_eq!(a.sum_axis(Axis(1)), arr1(&[3., 7.]));
47+
assert_eq!(a.product_axis(Axis(0)), arr1(&[3., 8.]));
48+
assert_eq!(a.product_axis(Axis(1)), arr1(&[2., 12.]));
4749
assert_eq!(a.mean_axis(Axis(0)), Some(arr1(&[2., 3.])));
4850
assert_eq!(a.mean_axis(Axis(1)), Some(arr1(&[1.5, 3.5])));
4951
assert_eq!(a.sum_axis(Axis(1)).sum_axis(Axis(0)), arr0(10.));
52+
assert_eq!(a.product_axis(Axis(1)).product_axis(Axis(0)), arr0(24.));
5053
assert_eq!(a.view().mean_axis(Axis(1)).unwrap(), aview1(&[1.5, 3.5]));
5154
assert_eq!(a.sum(), 10.);
5255
}
5356

5457
#[test]
55-
fn sum_mean_empty()
58+
fn sum_mean_prod_empty()
5659
{
5760
assert_eq!(Array3::<f32>::ones((2, 0, 3)).sum(), 0.);
61+
assert_eq!(Array3::<f32>::ones((2, 0, 3)).product(), 1.);
5862
assert_eq!(Array1::<f32>::ones(0).sum_axis(Axis(0)), arr0(0.));
63+
assert_eq!(Array1::<f32>::ones(0).product_axis(Axis(0)), arr0(1.));
5964
assert_eq!(
6065
Array3::<f32>::ones((2, 0, 3)).sum_axis(Axis(1)),
6166
Array::zeros((2, 3)),
6267
);
68+
assert_eq!(
69+
Array3::<f32>::ones((2, 0, 3)).product_axis(Axis(1)),
70+
Array::ones((2, 3)),
71+
);
6372
let a = Array1::<f32>::ones(0).mean_axis(Axis(0));
6473
assert_eq!(a, None);
6574
let a = Array3::<f32>::ones((2, 0, 3)).mean_axis(Axis(1));

0 commit comments

Comments
 (0)