Skip to content

Simplifications for slicing-related types #940

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/dimension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -599,8 +599,8 @@ fn slice_min_max(axis_len: usize, slice: Slice) -> Option<(usize, usize)> {
/// Returns `true` iff the slices intersect.
pub fn slices_intersect<D: Dimension>(
dim: &D,
indices1: &impl SliceArg<D>,
indices2: &impl SliceArg<D>,
indices1: impl SliceArg<D>,
indices2: impl SliceArg<D>,
) -> bool {
debug_assert_eq!(indices1.in_ndim(), indices2.in_ndim());
for (&axis_len, &si1, &si2) in izip!(
Expand Down
16 changes: 8 additions & 8 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,9 @@ where
///
/// **Panics** if an index is out of bounds or step size is zero.<br>
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.)
pub fn slice<I>(&self, info: &I) -> ArrayView<'_, A, I::OutDim>
pub fn slice<I>(&self, info: I) -> ArrayView<'_, A, I::OutDim>
where
I: SliceArg<D> + ?Sized,
I: SliceArg<D>,
S: Data,
{
self.view().slice_move(info)
Expand All @@ -353,9 +353,9 @@ where
///
/// **Panics** if an index is out of bounds or step size is zero.<br>
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.)
pub fn slice_mut<I>(&mut self, info: &I) -> ArrayViewMut<'_, A, I::OutDim>
pub fn slice_mut<I>(&mut self, info: I) -> ArrayViewMut<'_, A, I::OutDim>
where
I: SliceArg<D> + ?Sized,
I: SliceArg<D>,
S: DataMut,
{
self.view_mut().slice_move(info)
Expand Down Expand Up @@ -399,9 +399,9 @@ where
///
/// **Panics** if an index is out of bounds or step size is zero.<br>
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.)
pub fn slice_move<I>(mut self, info: &I) -> ArrayBase<S, I::OutDim>
pub fn slice_move<I>(mut self, info: I) -> ArrayBase<S, I::OutDim>
where
I: SliceArg<D> + ?Sized,
I: SliceArg<D>,
{
assert_eq!(
info.in_ndim(),
Expand Down Expand Up @@ -468,9 +468,9 @@ where
/// - if [`AxisSliceInfo::NewAxis`] is in `info`, e.g. if [`NewAxis`] was
/// used in the [`s!`] macro
/// - if `D` is `IxDyn` and `info` does not match the number of array axes
pub fn slice_collapse<I>(&mut self, info: &I)
pub fn slice_collapse<I>(&mut self, info: I)
where
I: SliceArg<D> + ?Sized,
I: SliceArg<D>,
{
assert_eq!(
info.in_ndim(),
Expand Down
7 changes: 3 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -492,9 +492,8 @@ pub type Ixs = isize;
///
/// The slicing argument can be passed using the macro [`s![]`](macro.s!.html),
/// which will be used in all examples. (The explicit form is an instance of
/// [`&SliceInfo`]; see its docs for more information.)
///
/// [`&SliceInfo`]: struct.SliceInfo.html
/// [`SliceInfo`] or another type which implements [`SliceArg`]; see their docs
/// for more information.)
///
/// If a range is used, the axis is preserved. If an index is used, that index
/// is selected and the axis is removed; this selects a subview. See
Expand All @@ -510,7 +509,7 @@ pub type Ixs = isize;
/// [`NewAxis`]: struct.NewAxis.html
///
/// When slicing arrays with generic dimensionality, creating an instance of
/// [`&SliceInfo`] to pass to the multi-axis slicing methods like [`.slice()`]
/// [`SliceInfo`] to pass to the multi-axis slicing methods like [`.slice()`]
/// is awkward. In these cases, it's usually more convenient to use
/// [`.slice_each_axis()`]/[`.slice_each_axis_mut()`]/[`.slice_each_axis_inplace()`]
/// or to create a view and then slice individual axes of the view using
Expand Down
88 changes: 40 additions & 48 deletions src/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ pub struct NewAxis;
/// A slice (range with step), an index, or a new axis token.
///
/// See also the [`s![]`](macro.s!.html) macro for a convenient way to create a
/// `&SliceInfo<[AxisSliceInfo; n], Din, Dout>`.
/// `SliceInfo<[AxisSliceInfo; n], Din, Dout>`.
///
/// ## Examples
///
Expand Down Expand Up @@ -324,6 +324,24 @@ pub unsafe trait SliceArg<D: Dimension>: AsRef<[AxisSliceInfo]> {
private_decl! {}
}

unsafe impl<T, D> SliceArg<D> for &T
where
T: SliceArg<D> + ?Sized,
D: Dimension,
{
type OutDim = T::OutDim;

fn in_ndim(&self) -> usize {
T::in_ndim(self)
}

fn out_ndim(&self) -> usize {
T::out_ndim(self)
}

private_impl! {}
}

macro_rules! impl_slicearg_samedim {
($in_dim:ty) => {
unsafe impl<T, Dout> SliceArg<$in_dim> for SliceInfo<T, $in_dim, Dout>
Expand Down Expand Up @@ -388,7 +406,7 @@ unsafe impl SliceArg<IxDyn> for [AxisSliceInfo] {

/// Represents all of the necessary information to perform a slice.
///
/// The type `T` is typically `[AxisSliceInfo; n]`, `[AxisSliceInfo]`, or
/// The type `T` is typically `[AxisSliceInfo; n]`, `&[AxisSliceInfo]`, or
/// `Vec<AxisSliceInfo>`. The type `Din` is the dimension of the array to be
/// sliced, and `Dout` is the output dimension after calling [`.slice()`]. Note
/// that if `Din` is a fixed dimension type (`Ix0`, `Ix1`, `Ix2`, etc.), the
Expand All @@ -397,14 +415,13 @@ unsafe impl SliceArg<IxDyn> for [AxisSliceInfo] {
///
/// [`.slice()`]: struct.ArrayBase.html#method.slice
#[derive(Debug)]
#[repr(transparent)]
pub struct SliceInfo<T: ?Sized, Din: Dimension, Dout: Dimension> {
pub struct SliceInfo<T, Din: Dimension, Dout: Dimension> {
in_dim: PhantomData<Din>,
out_dim: PhantomData<Dout>,
indices: T,
}

impl<T: ?Sized, Din, Dout> Deref for SliceInfo<T, Din, Dout>
impl<T, Din, Dout> Deref for SliceInfo<T, Din, Dout>
where
Din: Dimension,
Dout: Dimension,
Expand Down Expand Up @@ -464,14 +481,7 @@ where
indices,
}
}
}

impl<T, Din, Dout> SliceInfo<T, Din, Dout>
where
T: AsRef<[AxisSliceInfo]>,
Din: Dimension,
Dout: Dimension,
{
/// Returns a new `SliceInfo` instance.
///
/// Errors if `Din` or `Dout` is not consistent with `indices`.
Expand All @@ -490,14 +500,7 @@ where
indices,
})
}
}

impl<T: ?Sized, Din, Dout> SliceInfo<T, Din, Dout>
where
T: AsRef<[AxisSliceInfo]>,
Din: Dimension,
Dout: Dimension,
{
/// Returns the number of dimensions of the input array for
/// [`.slice()`](struct.ArrayBase.html#method.slice).
///
Expand Down Expand Up @@ -528,7 +531,7 @@ where
}
}

impl<'a, Din, Dout> TryFrom<&'a [AxisSliceInfo]> for &'a SliceInfo<[AxisSliceInfo], Din, Dout>
impl<'a, Din, Dout> TryFrom<&'a [AxisSliceInfo]> for SliceInfo<&'a [AxisSliceInfo], Din, Dout>
where
Din: Dimension,
Dout: Dimension,
Expand All @@ -537,16 +540,11 @@ where

fn try_from(
indices: &'a [AxisSliceInfo],
) -> Result<&'a SliceInfo<[AxisSliceInfo], Din, Dout>, ShapeError> {
check_dims_for_sliceinfo::<Din, Dout>(indices)?;
) -> Result<SliceInfo<&'a [AxisSliceInfo], Din, Dout>, ShapeError> {
unsafe {
// This is okay because we've already checked the correctness of
// `Din` and `Dout`, and the only non-zero-sized member of
// `SliceInfo` is `indices`, so `&SliceInfo<[AxisSliceInfo], Din,
// Dout>` should have the same bitwise representation as
// `&[AxisSliceInfo]`.
Ok(&*(indices as *const [AxisSliceInfo]
as *const SliceInfo<[AxisSliceInfo], Din, Dout>))
// This is okay because `&[AxisSliceInfo]` always returns the same
// value for `.as_ref()`.
Self::new(indices)
}
}
}
Expand Down Expand Up @@ -612,20 +610,18 @@ where
}
}

impl<T, Din, Dout> AsRef<SliceInfo<[AxisSliceInfo], Din, Dout>> for SliceInfo<T, Din, Dout>
impl<'a, T, Din, Dout> From<&'a SliceInfo<T, Din, Dout>>
for SliceInfo<&'a [AxisSliceInfo], Din, Dout>
where
T: AsRef<[AxisSliceInfo]>,
Din: Dimension,
Dout: Dimension,
{
fn as_ref(&self) -> &SliceInfo<[AxisSliceInfo], Din, Dout> {
unsafe {
// This is okay because the only non-zero-sized member of
// `SliceInfo` is `indices`, so `&SliceInfo<[AxisSliceInfo], Din, Dout>`
// should have the same bitwise representation as
// `&[AxisSliceInfo]`.
&*(self.indices.as_ref() as *const [AxisSliceInfo]
as *const SliceInfo<[AxisSliceInfo], Din, Dout>)
fn from(info: &'a SliceInfo<T, Din, Dout>) -> SliceInfo<&'a [AxisSliceInfo], Din, Dout> {
SliceInfo {
in_dim: info.in_dim,
out_dim: info.out_dim,
indices: info.indices.as_ref(),
}
}
}
Expand Down Expand Up @@ -703,9 +699,7 @@ impl_slicenextdim!((), NewAxis, Ix0, Ix1);
///
/// `s![]` takes a list of ranges/slices/indices/new-axes, separated by comma,
/// with optional step sizes that are separated from the range by a semicolon.
/// It is converted into a [`&SliceInfo`] instance.
///
/// [`&SliceInfo`]: struct.SliceInfo.html
/// It is converted into a [`SliceInfo`] instance.
///
/// Each range/slice/index uses signed indices, where a negative value is
/// counted from the end of the axis. Step sizes are also signed and may be
Expand Down Expand Up @@ -889,9 +883,7 @@ macro_rules! s(
<$crate::AxisSliceInfo as ::std::convert::From<_>>::from($r).step_by($s as isize)
};
($($t:tt)*) => {
// The extra `*&` is a workaround for this compiler bug:
// https://github.com/rust-lang/rust/issues/23014
&*&$crate::s![@parse
$crate::s![@parse
::std::marker::PhantomData::<$crate::Ix0>,
::std::marker::PhantomData::<$crate::Ix0>,
[]
Expand Down Expand Up @@ -933,7 +925,7 @@ where
private_impl! {}
}

impl<'a, A, D, I0> MultiSliceArg<'a, A, D> for (&I0,)
impl<'a, A, D, I0> MultiSliceArg<'a, A, D> for (I0,)
where
A: 'a,
D: Dimension,
Expand All @@ -942,7 +934,7 @@ where
type Output = (ArrayViewMut<'a, A, I0::OutDim>,);

fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output {
(view.slice_move(self.0),)
(view.slice_move(&self.0),)
}

private_impl! {}
Expand All @@ -953,7 +945,7 @@ macro_rules! impl_multislice_tuple {
impl_multislice_tuple!(@def_impl ($($but_last,)* $last,), [$($but_last)*] $last);
};
(@def_impl ($($all:ident,)*), [$($but_last:ident)*] $last:ident) => {
impl<'a, A, D, $($all,)*> MultiSliceArg<'a, A, D> for ($(&$all,)*)
impl<'a, A, D, $($all,)*> MultiSliceArg<'a, A, D> for ($($all,)*)
where
A: 'a,
D: Dimension,
Expand All @@ -963,7 +955,7 @@ macro_rules! impl_multislice_tuple {

fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output {
#[allow(non_snake_case)]
let &($($all,)*) = self;
let ($($all,)*) = self;

let shape = view.raw_dim();
assert!(!impl_multislice_tuple!(@intersects_self &shape, ($($all,)*)));
Expand Down
12 changes: 6 additions & 6 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ fn test_slice_dyninput_array_fixed() {
#[test]
fn test_slice_array_dyn() {
let mut arr = Array3::<f64>::zeros((5, 2, 5));
let info = &SliceInfo::<_, Ix3, IxDyn>::try_from([
let info = SliceInfo::<_, Ix3, IxDyn>::try_from([
AxisSliceInfo::from(1..),
AxisSliceInfo::from(1),
AxisSliceInfo::from(NewAxis),
Expand All @@ -229,7 +229,7 @@ fn test_slice_array_dyn() {
arr.slice(info);
arr.slice_mut(info);
arr.view().slice_move(info);
let info2 = &SliceInfo::<_, Ix3, IxDyn>::try_from([
let info2 = SliceInfo::<_, Ix3, IxDyn>::try_from([
AxisSliceInfo::from(1..),
AxisSliceInfo::from(1),
AxisSliceInfo::from(..).step_by(2),
Expand All @@ -241,7 +241,7 @@ fn test_slice_array_dyn() {
#[test]
fn test_slice_dyninput_array_dyn() {
let mut arr = Array3::<f64>::zeros((5, 2, 5)).into_dyn();
let info = &SliceInfo::<_, Ix3, IxDyn>::try_from([
let info = SliceInfo::<_, Ix3, IxDyn>::try_from([
AxisSliceInfo::from(1..),
AxisSliceInfo::from(1),
AxisSliceInfo::from(NewAxis),
Expand All @@ -251,7 +251,7 @@ fn test_slice_dyninput_array_dyn() {
arr.slice(info);
arr.slice_mut(info);
arr.view().slice_move(info);
let info2 = &SliceInfo::<_, Ix3, IxDyn>::try_from([
let info2 = SliceInfo::<_, Ix3, IxDyn>::try_from([
AxisSliceInfo::from(1..),
AxisSliceInfo::from(1),
AxisSliceInfo::from(..).step_by(2),
Expand All @@ -273,7 +273,7 @@ fn test_slice_dyninput_vec_fixed() {
arr.slice(info);
arr.slice_mut(info);
arr.view().slice_move(info);
let info2 = &SliceInfo::<_, Ix3, Ix2>::try_from(vec![
let info2 = SliceInfo::<_, Ix3, Ix2>::try_from(vec![
AxisSliceInfo::from(1..),
AxisSliceInfo::from(1),
AxisSliceInfo::from(..).step_by(2),
Expand All @@ -295,7 +295,7 @@ fn test_slice_dyninput_vec_dyn() {
arr.slice(info);
arr.slice_mut(info);
arr.view().slice_move(info);
let info2 = &SliceInfo::<_, Ix3, IxDyn>::try_from(vec![
let info2 = SliceInfo::<_, Ix3, IxDyn>::try_from(vec![
AxisSliceInfo::from(1..),
AxisSliceInfo::from(1),
AxisSliceInfo::from(..).step_by(2),
Expand Down
2 changes: 1 addition & 1 deletion tests/oper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ fn scaled_add_3() {

{
let mut av = a.slice_mut(s![..;s1, ..;s2]);
let c = c.slice(&SliceInfo::<_, IxDyn, IxDyn>::try_from(cslice).unwrap());
let c = c.slice(SliceInfo::<_, IxDyn, IxDyn>::try_from(cslice).unwrap());

let mut answerv = answer.slice_mut(s![..;s1, ..;s2]);
answerv += &(beta * &c);
Expand Down