Skip to content

Add an iter_axis function to get subviews along an axis #57

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 16, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 30 additions & 7 deletions src/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -445,19 +445,21 @@ pub struct OuterIterCore<A, D> {
ptr: *mut A,
}

fn new_outer_core<A, S, D>(v: ArrayBase<S, D>) -> OuterIterCore<A, D::Smaller>
fn new_outer_core<A, S, D>(v: ArrayBase<S, D>,
axis: usize
) -> OuterIterCore<A, D::Smaller>
where D: RemoveAxis,
S: Data<Elem=A>,
{
let shape = v.shape()[0];
let stride = v.strides()[0];
let shape = v.shape()[axis];
let stride = v.strides()[axis];

OuterIterCore {
index: 0,
len: shape,
stride: stride,
inner_dim: v.dim.remove_axis(0),
inner_strides: v.strides.remove_axis(0),
inner_dim: v.dim.remove_axis(axis),
inner_strides: v.strides.remove_axis(axis),
ptr: v.ptr,
}
}
Expand Down Expand Up @@ -554,11 +556,21 @@ pub fn new_outer_iter<A, D>(v: ArrayView<A, D>) -> OuterIter<A, D::Smaller>
where D: RemoveAxis,
{
OuterIter {
iter: new_outer_core(v),
iter: new_outer_core(v, 0),
life: PhantomData,
}
}

pub fn new_axis_iter<A, D>(v: ArrayView<A, D>, axis: usize) -> OuterIter<A, D::Smaller>
where D: RemoveAxis,
{
OuterIter {
iter: new_outer_core(v, axis),
life: PhantomData,
}
}


/// An iterator that traverses over the outermost dimension
/// and yields each subview (mutable).
///
Expand Down Expand Up @@ -614,7 +626,18 @@ pub fn new_outer_iter_mut<A, D>(v: ArrayViewMut<A, D>) -> OuterIterMut<A, D::Sma
where D: RemoveAxis,
{
OuterIterMut {
iter: new_outer_core(v),
iter: new_outer_core(v, 0),
life: PhantomData,
}
}

pub fn new_axis_iter_mut<A, D>(v: ArrayViewMut<A, D>,
axis: usize
) -> OuterIterMut<A, D::Smaller>
where D: RemoveAxis,
{
OuterIterMut {
iter: new_outer_core(v, axis),
life: PhantomData,
}
}
34 changes: 34 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1218,6 +1218,24 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
iterators::new_outer_iter(self.view())
}

/// Return an iterator that traverses over the `axis` dimension
/// and yields each subview.
///
/// For example, in a 2 × 2 × 3 array, with `axis` equal to 1,
/// the iterator element
/// is a 2 × 2 subview (and there are 3 in total).
///
/// Iterator element is `ArrayView<A, D::Smaller>` (read-only array view).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should say that it panics if axis is out of bounds

///
/// # Panics
///
/// If axis is out of bounds.
pub fn axis_iter(&self, axis: usize) -> OuterIter<A, D::Smaller>
where D: RemoveAxis
{
iterators::new_axis_iter(self.view(), axis)
}

/// Return an iterator that traverses over the outermost dimension
/// and yields each subview.
///
Expand All @@ -1229,6 +1247,22 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
iterators::new_outer_iter_mut(self.view_mut())
}

/// Return an iterator that traverses over the `axis` dimension
/// and yields each mutable subview.
///
/// Iterator element is `ArrayViewMut<A, D::Smaller>`
/// (read-write array view).
///
/// # Panics
///
/// If axis is out of bounds.
pub fn axis_iter_mut(&mut self, axis: usize) -> OuterIterMut<A, D::Smaller>
where S: DataMut,
D: RemoveAxis,
{
iterators::new_axis_iter_mut(self.view_mut(), axis)
}

// Return (length, stride) for diagonal
fn diag_params(&self) -> (Ix, Ixs)
{
Expand Down
42 changes: 42 additions & 0 deletions tests/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use ndarray::{
Data,
Dimension,
aview1,
arr3,
};

use itertools::assert_equal;
Expand Down Expand Up @@ -198,6 +199,22 @@ fn outer_iter() {
assert_eq!(&found_rows, &found_rows_rev);
}

#[test]
fn axis_iter() {
let a = Array::from_iter(0..12);
let a = a.reshape((2, 3, 2));
// [[[0, 1],
// [2, 3],
// [4, 5]],
// [[6, 7],
// [8, 9],
// ...
assert_equal(a.axis_iter(1),
vec![a.subview(1, 0),
a.subview(1, 1),
a.subview(1, 2)]);
}

#[test]
fn outer_iter_corner_cases() {
let a2 = Array::<i32, _>::zeros((0, 3));
Expand Down Expand Up @@ -234,6 +251,31 @@ fn outer_iter_mut() {
assert_equal(a.inner_iter(), found_rows);
}

#[test]
fn axis_iter_mut() {
let a = Array::from_iter(0..12);
let a = a.reshape((2, 3, 2));
// [[[0, 1],
// [2, 3],
// [4, 5]],
// [[6, 7],
// [8, 9],
// ...
let mut a = a.to_owned();

for mut subview in a.axis_iter_mut(1) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mut subview; it's sad quirks like this because we have no user-defined DSTs (views should be unsized like slices I think).

subview[[0, 0]] = 42;
}

let b = arr3(&[[[42, 1],
[42, 3],
[42, 5]],
[[6, 7],
[8, 9],
[10, 11]]]);
assert_eq!(a, b);
}

#[test]
fn outer_iter_size_hint() {
// Check that the size hint is correctly computed
Expand Down