Skip to content

Commit 9de552f

Browse files
committed
Add support for inserting new axes while slicing
1 parent 01c3d32 commit 9de552f

File tree

7 files changed

+173
-69
lines changed

7 files changed

+173
-69
lines changed

src/dimension/mod.rs

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,11 @@ pub fn slices_intersect<D: Dimension>(
599599
indices2: &impl CanSlice<D>,
600600
) -> bool {
601601
debug_assert_eq!(indices1.in_ndim(), indices2.in_ndim());
602-
for (&axis_len, &si1, &si2) in izip!(dim.slice(), indices1.as_ref(), indices2.as_ref()) {
602+
for (&axis_len, &si1, &si2) in izip!(
603+
dim.slice(),
604+
indices1.as_ref().iter().filter(|si| !si.is_new_axis()),
605+
indices2.as_ref().iter().filter(|si| !si.is_new_axis()),
606+
) {
603607
// The slices do not intersect iff any pair of `AxisSliceInfo` does not intersect.
604608
match (si1, si2) {
605609
(
@@ -647,6 +651,7 @@ pub fn slices_intersect<D: Dimension>(
647651
return false;
648652
}
649653
}
654+
(AxisSliceInfo::NewAxis, _) | (_, AxisSliceInfo::NewAxis) => unreachable!(),
650655
}
651656
}
652657
true
@@ -688,7 +693,7 @@ mod test {
688693
};
689694
use crate::error::{from_kind, ErrorKind};
690695
use crate::slice::Slice;
691-
use crate::{Dim, Dimension, Ix0, Ix1, Ix2, Ix3, IxDyn};
696+
use crate::{Dim, Dimension, Ix0, Ix1, Ix2, Ix3, IxDyn, NewAxis};
692697
use num_integer::gcd;
693698
use quickcheck::{quickcheck, TestResult};
694699

@@ -962,17 +967,45 @@ mod test {
962967

963968
#[test]
964969
fn slices_intersect_true() {
965-
assert!(slices_intersect(&Dim([4, 5]), s![.., ..], s![.., ..]));
966-
assert!(slices_intersect(&Dim([4, 5]), s![0, ..], s![0, ..]));
967-
assert!(slices_intersect(&Dim([4, 5]), s![..;2, ..], s![..;3, ..]));
968-
assert!(slices_intersect(&Dim([4, 5]), s![.., ..;2], s![.., 1..;3]));
970+
assert!(slices_intersect(
971+
&Dim([4, 5]),
972+
s![NewAxis, .., NewAxis, ..],
973+
s![.., NewAxis, .., NewAxis]
974+
));
975+
assert!(slices_intersect(
976+
&Dim([4, 5]),
977+
s![NewAxis, 0, ..],
978+
s![0, ..]
979+
));
980+
assert!(slices_intersect(
981+
&Dim([4, 5]),
982+
s![..;2, ..],
983+
s![..;3, NewAxis, ..]
984+
));
985+
assert!(slices_intersect(
986+
&Dim([4, 5]),
987+
s![.., ..;2],
988+
s![.., 1..;3, NewAxis]
989+
));
969990
assert!(slices_intersect(&Dim([4, 10]), s![.., ..;9], s![.., 3..;6]));
970991
}
971992

972993
#[test]
973994
fn slices_intersect_false() {
974-
assert!(!slices_intersect(&Dim([4, 5]), s![..;2, ..], s![1..;2, ..]));
975-
assert!(!slices_intersect(&Dim([4, 5]), s![..;2, ..], s![1..;3, ..]));
976-
assert!(!slices_intersect(&Dim([4, 5]), s![.., ..;9], s![.., 3..;6]));
995+
assert!(!slices_intersect(
996+
&Dim([4, 5]),
997+
s![..;2, ..],
998+
s![NewAxis, 1..;2, ..]
999+
));
1000+
assert!(!slices_intersect(
1001+
&Dim([4, 5]),
1002+
s![..;2, NewAxis, ..],
1003+
s![1..;3, ..]
1004+
));
1005+
assert!(!slices_intersect(
1006+
&Dim([4, 5]),
1007+
s![.., ..;9],
1008+
s![.., 3..;6, NewAxis]
1009+
));
9771010
}
9781011
}

src/doc/ndarray_for_numpy_users/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@
532532
//! `a[:] = b` | [`a.assign(&b)`][.assign()] | copy the data from array `b` into array `a`
533533
//! `np.concatenate((a,b), axis=1)` | [`stack![Axis(1), a, b]`][stack!] or [`stack(Axis(1), &[a.view(), b.view()])`][stack()] | concatenate arrays `a` and `b` along axis 1
534534
//! `np.stack((a,b), axis=1)` | [`stack_new_axis![Axis(1), a, b]`][stack_new_axis!] or [`stack_new_axis(Axis(1), vec![a.view(), b.view()])`][stack_new_axis()] | stack arrays `a` and `b` along axis 1
535-
//! `a[:,np.newaxis]` or `np.expand_dims(a, axis=1)` | [`a.insert_axis(Axis(1))`][.insert_axis()] | create an array from `a`, inserting a new axis 1
535+
//! `a[:,np.newaxis]` or `np.expand_dims(a, axis=1)` | [`a.slice(s![.., NewAxis])`][.slice()] or [`a.insert_axis(Axis(1))`][.insert_axis()] | create an view of 1-D array `a`, inserting a new axis 1
536536
//! `a.transpose()` or `a.T` | [`a.t()`][.t()] or [`a.reversed_axes()`][.reversed_axes()] | transpose of array `a` (view for `.t()` or by-move for `.reversed_axes()`)
537537
//! `np.diag(a)` | [`a.diag()`][.diag()] | view the diagonal of `a`
538538
//! `a.flatten()` | [`use std::iter::FromIterator; Array::from_iter(a.iter().cloned())`][::from_iter()] | create a 1-D array by flattening `a`

src/impl_methods.rs

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,12 @@ where
435435
// Skip the old axis since it should be removed.
436436
old_axis += 1;
437437
}
438+
AxisSliceInfo::NewAxis => {
439+
// Set the dim and stride of the new axis.
440+
new_dim[new_axis] = 1;
441+
new_strides[new_axis] = 0;
442+
new_axis += 1;
443+
}
438444
});
439445
debug_assert_eq!(old_axis, self.ndim());
440446
debug_assert_eq!(new_axis, out_ndim);
@@ -448,6 +454,8 @@ where
448454

449455
/// Slice the array in place without changing the number of dimensions.
450456
///
457+
/// Note that `NewAxis` elements in `info` are ignored.
458+
///
451459
/// See [*Slicing*](#slicing) for full documentation.
452460
///
453461
/// **Panics** if an index is out of bounds or step size is zero.<br>
@@ -461,18 +469,20 @@ where
461469
self.ndim(),
462470
"The input dimension of `info` must match the array to be sliced.",
463471
);
464-
info.as_ref()
465-
.iter()
466-
.enumerate()
467-
.for_each(|(axis, &ax_info)| match ax_info {
472+
let mut axis = 0;
473+
info.as_ref().iter().for_each(|&ax_info| match ax_info {
468474
AxisSliceInfo::Slice { start, end, step } => {
469-
self.slice_axis_inplace(Axis(axis), Slice { start, end, step })
475+
self.slice_axis_inplace(Axis(axis), Slice { start, end, step });
476+
axis += 1;
470477
}
471478
AxisSliceInfo::Index(index) => {
472479
let i_usize = abs_index(self.len_of(Axis(axis)), index);
473-
self.collapse_axis(Axis(axis), i_usize)
480+
self.collapse_axis(Axis(axis), i_usize);
481+
axis += 1;
474482
}
483+
AxisSliceInfo::NewAxis => {}
475484
});
485+
debug_assert_eq!(axis, self.ndim());
476486
}
477487

478488
/// Return a view of the array, sliced along the specified axis.

src/lib.rs

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ pub use crate::dimension::IxDynImpl;
140140
pub use crate::dimension::NdIndex;
141141
pub use crate::error::{ErrorKind, ShapeError};
142142
pub use crate::indexes::{indices, indices_of};
143-
pub use crate::slice::{AxisSliceInfo, Slice, SliceInfo, SliceNextInDim, SliceNextOutDim};
143+
pub use crate::slice::{AxisSliceInfo, NewAxis, Slice, SliceInfo, SliceNextInDim, SliceNextOutDim};
144144

145145
use crate::iterators::Baseiter;
146146
use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes, LanesMut};
@@ -494,14 +494,16 @@ pub type Ixs = isize;
494494
///
495495
/// If a range is used, the axis is preserved. If an index is used, that index
496496
/// is selected and the axis is removed; this selects a subview. See
497-
/// [*Subviews*](#subviews) for more information about subviews. Note that
498-
/// [`.slice_collapse()`] behaves like [`.collapse_axis()`] by preserving
499-
/// the number of dimensions.
497+
/// [*Subviews*](#subviews) for more information about subviews. If a
498+
/// [`NewAxis`] instance is used, a new axis is inserted. Note that
499+
/// [`.slice_collapse()`] ignores `NewAxis` elements and behaves like
500+
/// [`.collapse_axis()`] by preserving the number of dimensions.
500501
///
501502
/// [`.slice()`]: #method.slice
502503
/// [`.slice_mut()`]: #method.slice_mut
503504
/// [`.slice_move()`]: #method.slice_move
504505
/// [`.slice_collapse()`]: #method.slice_collapse
506+
/// [`NewAxis`]: struct.NewAxis.html
505507
///
506508
/// It's possible to take multiple simultaneous *mutable* slices with
507509
/// [`.multi_slice_mut()`] or (for [`ArrayViewMut`] only)
@@ -512,7 +514,7 @@ pub type Ixs = isize;
512514
///
513515
/// ```
514516
///
515-
/// use ndarray::{arr2, arr3, s};
517+
/// use ndarray::{arr2, arr3, s, NewAxis};
516518
///
517519
/// // 2 submatrices of 2 rows with 3 elements per row, means a shape of `[2, 2, 3]`.
518520
///
@@ -547,16 +549,17 @@ pub type Ixs = isize;
547549
/// assert_eq!(d, e);
548550
/// assert_eq!(d.shape(), &[2, 1, 3]);
549551
///
550-
/// // Let’s create a slice while selecting a subview with
552+
/// // Let’s create a slice while selecting a subview and inserting a new axis with
551553
/// //
552554
/// // - Both submatrices of the greatest dimension: `..`
553555
/// // - The last row in each submatrix, removing that axis: `-1`
554556
/// // - Row elements in reverse order: `..;-1`
555-
/// let f = a.slice(s![.., -1, ..;-1]);
556-
/// let g = arr2(&[[ 6, 5, 4],
557-
/// [12, 11, 10]]);
557+
/// // - A new axis at the end.
558+
/// let f = a.slice(s![.., -1, ..;-1, NewAxis]);
559+
/// let g = arr3(&[[ [6], [5], [4]],
560+
/// [[12], [11], [10]]]);
558561
/// assert_eq!(f, g);
559-
/// assert_eq!(f.shape(), &[2, 3]);
562+
/// assert_eq!(f.shape(), &[2, 3, 1]);
560563
///
561564
/// // Let's take two disjoint, mutable slices of a matrix with
562565
/// //

src/prelude.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ pub use crate::{array, azip, s};
4949
#[doc(no_inline)]
5050
pub use crate::ShapeBuilder;
5151

52+
#[doc(no_inline)]
53+
pub use crate::NewAxis;
54+
5255
#[doc(no_inline)]
5356
pub use crate::AsArray;
5457

0 commit comments

Comments
 (0)