Skip to content

Commit cb75694

Browse files
committed
safe stride constructor returns a Result
1 parent faf02c5 commit cb75694

File tree

5 files changed

+62
-13
lines changed

5 files changed

+62
-13
lines changed

src/dimension.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::slice;
22

33
use super::{Si, Ix, Ixs};
44
use super::zipsl;
5+
use stride_error::StrideError;
56

67
/// Calculate offset from `Ix` stride converting sign properly
78
#[inline]
@@ -68,7 +69,7 @@ pub fn dim_stride_overlap<D: Dimension>(dim: &D, strides: &D) -> bool
6869
pub fn can_index_slice<A, D: Dimension>(data: &[A],
6970
dim: &D,
7071
strides: &D
71-
) -> bool
72+
) -> Result<(), StrideError>
7273
{
7374
if strides.slice().iter().cloned().all(stride_is_positive) {
7475
let mut last_index = dim.clone();
@@ -79,16 +80,16 @@ pub fn can_index_slice<A, D: Dimension>(data: &[A],
7980
// offset is guaranteed to be positive so no issue converting
8081
// to usize here
8182
if (offset as usize) >= data.len() {
82-
return false;
83+
return Err(StrideError::OutOfBoundsStride);
8384
}
8485
if dim_stride_overlap(dim, strides) {
85-
return false;
86+
return Err(StrideError::OutOfBoundsStride);
8687
}
8788
}
88-
true
89+
Ok(())
8990
}
9091
else {
91-
false
92+
Err(StrideError::NegativeStride)
9293
}
9394
}
9495

src/lib.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ pub use dimension::{
8484
pub use dimension::NdIndex;
8585
pub use indexes::Indexes;
8686
pub use shape_error::ShapeError;
87+
pub use stride_error::StrideError;
8788
pub use si::{Si, S};
8889

8990
use dimension::stride_offset;
@@ -111,6 +112,7 @@ mod iterators;
111112
mod numeric_util;
112113
mod si;
113114
mod shape_error;
115+
mod stride_error;
114116

115117
// NOTE: In theory, the whole library should compile
116118
// and pass tests even if you change Ix and Ixs.
@@ -665,13 +667,15 @@ impl<S, A, D> ArrayBase<S, D>
665667
pub fn from_vec_dim_stride(dim: D,
666668
strides: D,
667669
v: Vec<A>
668-
) -> ArrayBase<S, D>
670+
) -> Result<ArrayBase<S, D>, StrideError>
669671
{
670-
assert!(dimension::can_index_slice(&v, &dim, &strides),
671-
"dim and strides index out of the vector's memory");
672-
unsafe {
673-
Self::from_vec_dim_stride_uchk(dim, strides, v)
674-
}
672+
dimension::can_index_slice(&v,
673+
&dim,
674+
&strides).map(|_| {
675+
unsafe {
676+
Self::from_vec_dim_stride_uchk(dim, strides, v)
677+
}
678+
})
675679
}
676680
}
677681

src/stride_error.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
use std::fmt;
2+
use std::error::Error;
3+
4+
/// An error to describe invalid stride states
5+
#[derive(Clone, Debug, PartialEq)]
6+
pub enum StrideError {
7+
/// stride leads to out of bounds indexing
8+
OutOfBoundsStride,
9+
/// stride leads to aliasing array elements
10+
AliasingStride,
11+
/// negative strides are unsafe in constructors
12+
NegativeStride,
13+
}
14+
15+
impl Error for StrideError {
16+
fn description(&self) -> &str {
17+
match *self {
18+
StrideError::OutOfBoundsStride =>
19+
"stride leads to out of bounds indexing",
20+
StrideError::AliasingStride =>
21+
"stride leads to aliasing array elements",
22+
StrideError::NegativeStride =>
23+
"negative strides are unsafe in constructors",
24+
}
25+
}
26+
}
27+
28+
impl fmt::Display for StrideError {
29+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
30+
self.description().fmt(f)
31+
}
32+
}

tests/array.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,16 @@ fn owned_array1() {
394394
assert_eq!(d1, d2);
395395
}
396396

397+
#[test]
398+
fn owned_array_with_stride() {
399+
let v: Vec<_> = (0..12).collect();
400+
let dim = (2, 3, 2);
401+
let strides = (1, 4, 2);
402+
403+
let a = OwnedArray::from_vec_dim_stride(dim, strides, v).unwrap();
404+
assert_eq!(a.strides(), &[1, 4, 2]);
405+
}
406+
397407
#[test]
398408
fn views() {
399409
let a = Array::from_vec(vec![1, 2, 3, 4]).reshape((2, 2));

tests/dimension.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use ndarray::{
66
arr2,
77
can_index_slice,
88
dim_stride_overlap,
9+
StrideError,
910
};
1011

1112
#[test]
@@ -44,10 +45,11 @@ fn slice_indexing_uncommon_strides()
4445
let v: Vec<_> = (0..12).collect();
4546
let dim = (2, 3, 2);
4647
let strides = (1, 2, 6);
47-
assert!(can_index_slice(&v, &dim, &strides));
48+
assert!(can_index_slice(&v, &dim, &strides).is_ok());
4849

4950
let strides = (2, 4, 12);
50-
assert!(!can_index_slice(&v, &dim, &strides));
51+
assert_eq!(can_index_slice(&v, &dim, &strides),
52+
Err(StrideError::OutOfBoundsStride));
5153
}
5254

5355
#[test]

0 commit comments

Comments
 (0)