Skip to content

Commit b509aa7

Browse files
committed
add axis_iter_mut
1 parent 3b16566 commit b509aa7

File tree

3 files changed

+53
-0
lines changed

3 files changed

+53
-0
lines changed

src/iterators.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -630,3 +630,14 @@ pub fn new_outer_iter_mut<A, D>(v: ArrayViewMut<A, D>) -> OuterIterMut<A, D::Sma
630630
life: PhantomData,
631631
}
632632
}
633+
634+
pub fn new_axis_iter_mut<A, D>(v: ArrayViewMut<A, D>,
635+
axis: usize
636+
) -> OuterIterMut<A, D::Smaller>
637+
where D: RemoveAxis,
638+
{
639+
OuterIterMut {
640+
iter: new_outer_core(v, axis),
641+
life: PhantomData,
642+
}
643+
}

src/lib.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,6 +1247,22 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
12471247
iterators::new_outer_iter_mut(self.view_mut())
12481248
}
12491249

1250+
/// Return an iterator that traverses over the `axis` dimension
1251+
/// and yields each mutable subview.
1252+
///
1253+
/// Iterator element is `ArrayViewMut<A, D::Smaller>`
1254+
/// (read-write array view).
1255+
///
1256+
/// # Panics
1257+
///
1258+
/// If axis is out of bounds.
1259+
pub fn axis_iter_mut(&mut self, axis: usize) -> OuterIterMut<A, D::Smaller>
1260+
where S: DataMut,
1261+
D: RemoveAxis,
1262+
{
1263+
iterators::new_axis_iter_mut(self.view_mut(), axis)
1264+
}
1265+
12501266
// Return (length, stride) for diagonal
12511267
fn diag_params(&self) -> (Ix, Ixs)
12521268
{

tests/iterators.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use ndarray::{
88
Data,
99
Dimension,
1010
aview1,
11+
arr3,
1112
};
1213

1314
use itertools::assert_equal;
@@ -250,6 +251,31 @@ fn outer_iter_mut() {
250251
assert_equal(a.inner_iter(), found_rows);
251252
}
252253

254+
#[test]
255+
fn axis_iter_mut() {
256+
let a = Array::from_iter(0..12);
257+
let a = a.reshape((2, 3, 2));
258+
// [[[0, 1],
259+
// [2, 3],
260+
// [4, 5]],
261+
// [[6, 7],
262+
// [8, 9],
263+
// ...
264+
let mut a = a.to_owned();
265+
266+
for mut subview in a.axis_iter_mut(1) {
267+
subview[[0, 0]] = 42;
268+
}
269+
270+
let b = arr3(&[[[42, 1],
271+
[42, 3],
272+
[42, 5]],
273+
[[6, 7],
274+
[8, 9],
275+
[10, 11]]]);
276+
assert_eq!(a, b);
277+
}
278+
253279
#[test]
254280
fn outer_iter_size_hint() {
255281
// Check that the size hint is correctly computed

0 commit comments

Comments
 (0)