Skip to content

Commit 52e145f

Browse files
committed
Merge pull request #57 from vbarrielle/iter_axis
Add an iter_axis function to get subviews along an axis
2 parents 337d758 + b509aa7 commit 52e145f

File tree

3 files changed

+106
-7
lines changed

3 files changed

+106
-7
lines changed

src/iterators.rs

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -445,19 +445,21 @@ pub struct OuterIterCore<A, D> {
445445
ptr: *mut A,
446446
}
447447

448-
fn new_outer_core<A, S, D>(v: ArrayBase<S, D>) -> OuterIterCore<A, D::Smaller>
448+
fn new_outer_core<A, S, D>(v: ArrayBase<S, D>,
449+
axis: usize
450+
) -> OuterIterCore<A, D::Smaller>
449451
where D: RemoveAxis,
450452
S: Data<Elem=A>,
451453
{
452-
let shape = v.shape()[0];
453-
let stride = v.strides()[0];
454+
let shape = v.shape()[axis];
455+
let stride = v.strides()[axis];
454456

455457
OuterIterCore {
456458
index: 0,
457459
len: shape,
458460
stride: stride,
459-
inner_dim: v.dim.remove_axis(0),
460-
inner_strides: v.strides.remove_axis(0),
461+
inner_dim: v.dim.remove_axis(axis),
462+
inner_strides: v.strides.remove_axis(axis),
461463
ptr: v.ptr,
462464
}
463465
}
@@ -554,11 +556,21 @@ pub fn new_outer_iter<A, D>(v: ArrayView<A, D>) -> OuterIter<A, D::Smaller>
554556
where D: RemoveAxis,
555557
{
556558
OuterIter {
557-
iter: new_outer_core(v),
559+
iter: new_outer_core(v, 0),
558560
life: PhantomData,
559561
}
560562
}
561563

564+
pub fn new_axis_iter<A, D>(v: ArrayView<A, D>, axis: usize) -> OuterIter<A, D::Smaller>
565+
where D: RemoveAxis,
566+
{
567+
OuterIter {
568+
iter: new_outer_core(v, axis),
569+
life: PhantomData,
570+
}
571+
}
572+
573+
562574
/// An iterator that traverses over the outermost dimension
563575
/// and yields each subview (mutable).
564576
///
@@ -614,7 +626,18 @@ pub fn new_outer_iter_mut<A, D>(v: ArrayViewMut<A, D>) -> OuterIterMut<A, D::Sma
614626
where D: RemoveAxis,
615627
{
616628
OuterIterMut {
617-
iter: new_outer_core(v),
629+
iter: new_outer_core(v, 0),
630+
life: PhantomData,
631+
}
632+
}
633+
634+
pub fn new_axis_iter_mut<A, D>(v: ArrayViewMut<A, D>,
635+
axis: usize
636+
) -> OuterIterMut<A, D::Smaller>
637+
where D: RemoveAxis,
638+
{
639+
OuterIterMut {
640+
iter: new_outer_core(v, axis),
618641
life: PhantomData,
619642
}
620643
}

src/lib.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,6 +1218,24 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
12181218
iterators::new_outer_iter(self.view())
12191219
}
12201220

1221+
/// Return an iterator that traverses over the `axis` dimension
1222+
/// and yields each subview.
1223+
///
1224+
/// For example, in a 2 × 2 × 3 array, with `axis` equal to 1,
1225+
/// the iterator element
1226+
/// is a 2 × 2 subview (and there are 3 in total).
1227+
///
1228+
/// Iterator element is `ArrayView<A, D::Smaller>` (read-only array view).
1229+
///
1230+
/// # Panics
1231+
///
1232+
/// If axis is out of bounds.
1233+
pub fn axis_iter(&self, axis: usize) -> OuterIter<A, D::Smaller>
1234+
where D: RemoveAxis
1235+
{
1236+
iterators::new_axis_iter(self.view(), axis)
1237+
}
1238+
12211239
/// Return an iterator that traverses over the outermost dimension
12221240
/// and yields each subview.
12231241
///
@@ -1229,6 +1247,22 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
12291247
iterators::new_outer_iter_mut(self.view_mut())
12301248
}
12311249

1250+
/// Return an iterator that traverses over the `axis` dimension
1251+
/// and yields each mutable subview.
1252+
///
1253+
/// Iterator element is `ArrayViewMut<A, D::Smaller>`
1254+
/// (read-write array view).
1255+
///
1256+
/// # Panics
1257+
///
1258+
/// If axis is out of bounds.
1259+
pub fn axis_iter_mut(&mut self, axis: usize) -> OuterIterMut<A, D::Smaller>
1260+
where S: DataMut,
1261+
D: RemoveAxis,
1262+
{
1263+
iterators::new_axis_iter_mut(self.view_mut(), axis)
1264+
}
1265+
12321266
// Return (length, stride) for diagonal
12331267
fn diag_params(&self) -> (Ix, Ixs)
12341268
{

tests/iterators.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use ndarray::{
99
Data,
1010
Dimension,
1111
aview1,
12+
arr3,
1213
};
1314

1415
use itertools::assert_equal;
@@ -216,6 +217,22 @@ fn outer_iter() {
216217
assert_equal(a.inner_iter(), found_rows);
217218
}
218219

220+
#[test]
221+
fn axis_iter() {
222+
let a = Array::from_iter(0..12);
223+
let a = a.reshape((2, 3, 2));
224+
// [[[0, 1],
225+
// [2, 3],
226+
// [4, 5]],
227+
// [[6, 7],
228+
// [8, 9],
229+
// ...
230+
assert_equal(a.axis_iter(1),
231+
vec![a.subview(1, 0),
232+
a.subview(1, 1),
233+
a.subview(1, 2)]);
234+
}
235+
219236
#[test]
220237
fn outer_iter_corner_cases() {
221238
let a2 = Array::<i32, _>::zeros((0, 3));
@@ -252,6 +269,31 @@ fn outer_iter_mut() {
252269
assert_equal(a.inner_iter(), found_rows);
253270
}
254271

272+
#[test]
273+
fn axis_iter_mut() {
274+
let a = Array::from_iter(0..12);
275+
let a = a.reshape((2, 3, 2));
276+
// [[[0, 1],
277+
// [2, 3],
278+
// [4, 5]],
279+
// [[6, 7],
280+
// [8, 9],
281+
// ...
282+
let mut a = a.to_owned();
283+
284+
for mut subview in a.axis_iter_mut(1) {
285+
subview[[0, 0]] = 42;
286+
}
287+
288+
let b = arr3(&[[[42, 1],
289+
[42, 3],
290+
[42, 5]],
291+
[[6, 7],
292+
[8, 9],
293+
[10, 11]]]);
294+
assert_eq!(a, b);
295+
}
296+
255297
#[test]
256298
fn outer_iter_size_hint() {
257299
// Check that the size hint is correctly computed

0 commit comments

Comments
 (0)