Skip to content

Commit 6d960a4

Browse files
committed
Merge pull request #60 from vbarrielle/stride_constructor
Implement owning constructors taking dimensions and strides
2 parents 3b0dfe4 + b766498 commit 6d960a4

File tree

4 files changed

+344
-1
lines changed

4 files changed

+344
-1
lines changed

src/dimension.rs

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::slice;
22

33
use super::{Si, Ix, Ixs};
44
use super::zipsl;
5+
use stride_error::StrideError;
56

67
/// Calculate offset from `Ix` stride converting sign properly
78
#[inline]
@@ -10,6 +11,88 @@ pub fn stride_offset(n: Ix, stride: Ix) -> isize
1011
(n as isize) * ((stride as Ixs) as isize)
1112
}
1213

14+
/// Check whether `stride` is strictly positive
15+
#[inline]
16+
pub fn stride_is_positive(stride: Ix) -> bool
17+
{
18+
(stride as Ixs) > 0
19+
}
20+
21+
/// Return the axis ordering corresponding to the fastest variation
22+
///
23+
/// Assumes that no stride value appears twice. This cannot yield the correct
24+
/// result the strides are not positive.
25+
fn fastest_varying_order<D: Dimension>(strides: &D) -> D
26+
{
27+
let mut sorted = strides.clone();
28+
sorted.slice_mut().sort();
29+
let mut res = strides.clone();
30+
for (ind, &val) in strides.slice().iter().enumerate() {
31+
let sorted_ind = sorted.slice()
32+
.iter()
33+
.position(|&x| x == val)
34+
.unwrap(); // cannot panic by construction
35+
res.slice_mut()[sorted_ind] = ind;
36+
}
37+
res
38+
}
39+
40+
/// Check whether the given `dim` and `stride` lead to overlapping indices
41+
///
42+
/// There is overlap if, when iterating through the dimensions in the order
43+
/// of maximum variation, the current stride is inferior to the sum of all
44+
/// preceding strides multiplied by their corresponding dimensions.
45+
///
46+
/// The current implementation assumes strides to be positive
47+
pub fn dim_stride_overlap<D: Dimension>(dim: &D, strides: &D) -> bool
48+
{
49+
let order = fastest_varying_order(strides);
50+
51+
let mut prev_offset = 1;
52+
for &ind in order.slice().iter() {
53+
let s = strides.slice()[ind];
54+
if (s as isize) < prev_offset {
55+
return true;
56+
}
57+
prev_offset = stride_offset(dim.slice()[ind], s);
58+
}
59+
false
60+
}
61+
62+
/// Check whether the given dimension and strides are memory safe
63+
/// to index the provided slice.
64+
///
65+
/// To be safe, no stride may be negative, and the offset corresponding
66+
/// to the last element of each dimension should be smaller than the length
67+
/// of the slice. Also, the strides should not allow a same element to be
68+
/// referenced by two different index.
69+
pub fn can_index_slice<A, D: Dimension>(data: &[A],
70+
dim: &D,
71+
strides: &D
72+
) -> Result<(), StrideError>
73+
{
74+
if strides.slice().iter().cloned().all(stride_is_positive) {
75+
let mut last_index = dim.clone();
76+
for mut index in last_index.slice_mut().iter_mut() {
77+
*index -= 1;
78+
}
79+
if let Some(offset) = dim.stride_offset_checked(strides, &last_index) {
80+
// offset is guaranteed to be positive so no issue converting
81+
// to usize here
82+
if (offset as usize) >= data.len() {
83+
return Err(StrideError::OutOfBounds);
84+
}
85+
if dim_stride_overlap(dim, strides) {
86+
return Err(StrideError::Aliasing);
87+
}
88+
}
89+
Ok(())
90+
}
91+
else {
92+
Err(StrideError::Aliasing)
93+
}
94+
}
95+
1396
/// Trait for the shape and index types of arrays.
1497
///
1598
/// `unsafe` because of the assumptions in the default methods.
@@ -69,6 +152,26 @@ pub unsafe trait Dimension : Clone + Eq {
69152
strides
70153
}
71154

155+
fn fortran_strides(&self) -> Self {
156+
// Compute fortran array strides
157+
// Shape (a, b, c) => Give strides (1, a, a * b)
158+
let mut strides = self.clone();
159+
{
160+
let mut it = strides.slice_mut().iter_mut();
161+
// Set first element to 1
162+
for rs in it.by_ref() {
163+
*rs = 1;
164+
break;
165+
}
166+
let mut cum_prod = 1;
167+
for (rs, dim) in it.zip(self.slice().iter()) {
168+
cum_prod *= *dim;
169+
*rs = cum_prod;
170+
}
171+
}
172+
strides
173+
}
174+
72175
#[inline]
73176
fn first_index(&self) -> Option<Self>
74177
{
@@ -529,3 +632,39 @@ unsafe impl<'a> NdIndex for &'a [Ix] {
529632
Some(offset)
530633
}
531634
}
635+
636+
#[cfg(test)]
637+
mod test {
638+
use super::{Dimension};
639+
use stride_error::StrideError;
640+
641+
#[test]
642+
fn fastest_varying_order() {
643+
let strides = (2, 8, 4, 1);
644+
let order = super::fastest_varying_order(&strides);
645+
assert_eq!(order.slice(), &[3, 0, 2, 1]);
646+
}
647+
648+
#[test]
649+
fn slice_indexing_uncommon_strides()
650+
{
651+
let v: Vec<_> = (0..12).collect();
652+
let dim = (2, 3, 2);
653+
let strides = (1, 2, 6);
654+
assert!(super::can_index_slice(&v, &dim, &strides).is_ok());
655+
656+
let strides = (2, 4, 12);
657+
assert_eq!(super::can_index_slice(&v, &dim, &strides),
658+
Err(StrideError::OutOfBounds));
659+
}
660+
661+
#[test]
662+
fn overlapping_strides_dim()
663+
{
664+
let dim = (2, 3, 2);
665+
let strides = (5, 2, 1);
666+
assert!(super::dim_stride_overlap(&dim, &strides));
667+
let strides = (6, 2, 1);
668+
assert!(!super::dim_stride_overlap(&dim, &strides));
669+
}
670+
}

src/lib.rs

Lines changed: 167 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,15 @@ use std::marker::PhantomData;
7575

7676
use itertools::ZipSlices;
7777

78-
pub use dimension::{Dimension, RemoveAxis};
78+
pub use dimension::{
79+
Dimension,
80+
RemoveAxis,
81+
};
82+
7983
pub use dimension::NdIndex;
8084
pub use indexes::Indexes;
8185
pub use shape_error::ShapeError;
86+
pub use stride_error::StrideError;
8287
pub use si::{Si, S};
8388

8489
use dimension::stride_offset;
@@ -107,6 +112,7 @@ mod linspace;
107112
mod numeric_util;
108113
mod si;
109114
mod shape_error;
115+
mod stride_error;
110116

111117
// NOTE: In theory, the whole library should compile
112118
// and pass tests even if you change Ix and Ixs.
@@ -605,12 +611,43 @@ impl<S, A, D> ArrayBase<S, D>
605611
}
606612
}
607613

614+
/// Create an array with copies of `elem`, dimension `dim` and fortran
615+
/// ordering.
616+
///
617+
/// ```
618+
/// use ndarray::Array;
619+
/// use ndarray::arr3;
620+
///
621+
/// let a = Array::from_elem_f((2, 2, 2), 1.);
622+
///
623+
/// assert!(
624+
/// a == arr3(&[[[1., 1.],
625+
/// [1., 1.]],
626+
/// [[1., 1.],
627+
/// [1., 1.]]])
628+
/// );
629+
/// assert!(a.strides() == &[1, 2, 4]);
630+
/// ```
631+
pub fn from_elem_f(dim: D, elem: A) -> ArrayBase<S, D> where A: Clone
632+
{
633+
let v = vec![elem; dim.size()];
634+
unsafe {
635+
Self::from_vec_dim_f(dim, v)
636+
}
637+
}
638+
608639
/// Create an array with zeros, dimension `dim`.
609640
pub fn zeros(dim: D) -> ArrayBase<S, D> where A: Clone + libnum::Zero
610641
{
611642
Self::from_elem(dim, libnum::zero())
612643
}
613644

645+
/// Create an array with zeros, dimension `dim` and fortran ordering.
646+
pub fn zeros_f(dim: D) -> ArrayBase<S, D> where A: Clone + libnum::Zero
647+
{
648+
Self::from_elem_f(dim, libnum::zero())
649+
}
650+
614651
/// Create an array with default values, dimension `dim`.
615652
pub fn default(dim: D) -> ArrayBase<S, D>
616653
where A: Default
@@ -634,6 +671,58 @@ impl<S, A, D> ArrayBase<S, D>
634671
dim: dim
635672
}
636673
}
674+
675+
/// Create an array from a vector (with no allocation needed),
676+
/// using fortran ordering to interpret the data.
677+
///
678+
/// Unsafe because dimension is unchecked, and must be correct.
679+
pub unsafe fn from_vec_dim_f(dim: D, mut v: Vec<A>) -> ArrayBase<S, D>
680+
{
681+
debug_assert!(dim.size() == v.len());
682+
ArrayBase {
683+
ptr: v.as_mut_ptr(),
684+
data: DataOwned::new(v),
685+
strides: dim.fortran_strides(),
686+
dim: dim
687+
}
688+
}
689+
690+
691+
/// Create an array from a vector and interpret it according to the
692+
/// provided dimensions and strides. No allocation needed.
693+
///
694+
/// Unsafe because dimension and strides are unchecked.
695+
pub unsafe fn from_vec_dim_stride_uchk(dim: D,
696+
strides: D,
697+
mut v: Vec<A>
698+
) -> ArrayBase<S, D>
699+
{
700+
ArrayBase {
701+
ptr: v.as_mut_ptr(),
702+
data: DataOwned::new(v),
703+
strides: strides,
704+
dim: dim
705+
}
706+
}
707+
708+
/// Create an array from a vector and interpret it according to the
709+
/// provided dimensions and strides. No allocation needed.
710+
///
711+
/// Checks whether `dim` and `strides` are compatible with the vector's
712+
/// length, returning an `Err` if not compatible.
713+
pub fn from_vec_dim_stride(dim: D,
714+
strides: D,
715+
v: Vec<A>
716+
) -> Result<ArrayBase<S, D>, StrideError>
717+
{
718+
dimension::can_index_slice(&v,
719+
&dim,
720+
&strides).map(|_| {
721+
unsafe {
722+
Self::from_vec_dim_stride_uchk(dim, strides, v)
723+
}
724+
})
725+
}
637726
}
638727

639728

@@ -666,6 +755,44 @@ impl<'a, A, D> ArrayView<'a, A, D>
666755
}
667756
}
668757

758+
/// Create an `ArrayView` borrowing its data from a slice.
759+
///
760+
/// Checks whether `dim` and `strides` are compatible with the slice's
761+
/// length, returning an `Err` if not compatible.
762+
///
763+
/// ```
764+
/// use ndarray::ArrayView;
765+
/// use ndarray::arr3;
766+
///
767+
/// let s = &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
768+
/// let a = ArrayView::from_slice_dim_stride((2, 3, 2),
769+
/// (1, 4, 2),
770+
/// s).unwrap();
771+
///
772+
/// assert!(
773+
/// a == arr3(&[[[0, 2],
774+
/// [4, 6],
775+
/// [8, 10]],
776+
/// [[1, 3],
777+
/// [5, 7],
778+
/// [9, 11]]])
779+
/// );
780+
/// assert!(a.strides() == &[1, 4, 2]);
781+
/// ```
782+
pub fn from_slice_dim_stride(dim: D,
783+
strides: D,
784+
s: &'a [A]
785+
) -> Result<Self, StrideError>
786+
{
787+
dimension::can_index_slice(s,
788+
&dim,
789+
&strides).map(|_| {
790+
unsafe {
791+
Self::new_(s.as_ptr(), dim, strides)
792+
}
793+
})
794+
}
795+
669796
#[inline]
670797
fn into_base_iter(self) -> Baseiter<'a, A, D> {
671798
unsafe {
@@ -723,6 +850,45 @@ impl<'a, A, D> ArrayViewMut<'a, A, D>
723850
}
724851
}
725852

853+
/// Create an `ArrayView` borrowing its data from a slice.
854+
///
855+
/// Checks whether `dim` and `strides` are compatible with the slice's
856+
/// length, returning an `Err` if not compatible.
857+
///
858+
/// ```
859+
/// use ndarray::ArrayViewMut;
860+
/// use ndarray::arr3;
861+
///
862+
/// let s = &mut [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
863+
/// let mut a = ArrayViewMut::from_slice_dim_stride((2, 3, 2),
864+
/// (1, 4, 2),
865+
/// s).unwrap();
866+
///
867+
/// a[[0, 0, 0]] = 1;
868+
/// assert!(
869+
/// a == arr3(&[[[1, 2],
870+
/// [4, 6],
871+
/// [8, 10]],
872+
/// [[1, 3],
873+
/// [5, 7],
874+
/// [9, 11]]])
875+
/// );
876+
/// assert!(a.strides() == &[1, 4, 2]);
877+
/// ```
878+
pub fn from_slice_dim_stride(dim: D,
879+
strides: D,
880+
s: &'a mut [A]
881+
) -> Result<Self, StrideError>
882+
{
883+
dimension::can_index_slice(s,
884+
&dim,
885+
&strides).map(|_| {
886+
unsafe {
887+
Self::new_(s.as_mut_ptr(), dim, strides)
888+
}
889+
})
890+
}
891+
726892
#[inline]
727893
fn into_base_iter(self) -> Baseiter<'a, A, D> {
728894
unsafe {

0 commit comments

Comments
 (0)