Skip to content

Commit f0dafb3

Browse files
SparrowLiibluss
authored and committed
FEAT: Fix memory continuity judgment when stride is negative
1 parent 4d9641d commit f0dafb3

File tree

7 files changed

+76
-39
lines changed

7 files changed

+76
-39
lines changed

blas-tests/tests/oper.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ where
173173
S2: Data<Elem = A>,
174174
{
175175
let ((m, _), k) = (lhs.dim(), rhs.dim());
176-
reference_mat_mul(lhs, &rhs.to_owned().into_shape((k, 1)).unwrap())
176+
reference_mat_mul(lhs, &rhs.as_standard_layout().into_shape((k, 1)).unwrap())
177177
.into_shape(m)
178178
.unwrap()
179179
}
@@ -186,7 +186,7 @@ where
186186
S2: Data<Elem = A>,
187187
{
188188
let (m, (_, n)) = (lhs.dim(), rhs.dim());
189-
reference_mat_mul(&lhs.to_owned().into_shape((1, m)).unwrap(), rhs)
189+
reference_mat_mul(&lhs.as_standard_layout().into_shape((1, m)).unwrap(), rhs)
190190
.into_shape(n)
191191
.unwrap()
192192
}

src/dimension/dimension_trait.rs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -286,17 +286,16 @@ pub trait Dimension:
286286
return true;
287287
}
288288
if dim.ndim() == 1 {
289-
return false;
289+
return strides[0] as isize == -1;
290290
}
291291
let order = strides._fastest_varying_stride_order();
292292
let strides = strides.slice();
293293

294-
// FIXME: Negative strides
295294
let dim_slice = dim.slice();
296295
let mut cstride = 1;
297296
for &i in order.slice() {
298297
// a dimension of length 1 can have unequal strides
299-
if dim_slice[i] != 1 && strides[i] != cstride {
298+
if dim_slice[i] != 1 && (strides[i] as isize).abs() as usize != cstride {
300299
return false;
301300
}
302301
cstride *= dim_slice[i];
@@ -307,16 +306,16 @@ pub trait Dimension:
307306
/// Return the axis ordering corresponding to the fastest variation
308307
/// (in ascending order).
309308
///
310-
/// Assumes that no stride value appears twice. This cannot yield the correct
311-
/// result the strides are not positive.
309+
/// Assumes that no stride value appears twice.
312310
#[doc(hidden)]
313311
fn _fastest_varying_stride_order(&self) -> Self {
314312
let mut indices = self.clone();
315313
for (i, elt) in enumerate(indices.slice_mut()) {
316314
*elt = i;
317315
}
318316
let strides = self.slice();
319-
indices.slice_mut().sort_by_key(|&i| strides[i]);
317+
indices.slice_mut()
318+
.sort_by_key(|&i| (strides[i] as isize).abs());
320319
indices
321320
}
322321

@@ -645,7 +644,7 @@ impl Dimension for Dim<[Ix; 2]> {
645644

646645
#[inline]
647646
fn _fastest_varying_stride_order(&self) -> Self {
648-
if get!(self, 0) as Ixs <= get!(self, 1) as Ixs {
647+
if (get!(self, 0) as Ixs).abs() <= (get!(self, 1) as Ixs).abs() {
649648
Ix2(0, 1)
650649
} else {
651650
Ix2(1, 0)
@@ -805,7 +804,7 @@ impl Dimension for Dim<[Ix; 3]> {
805804
let mut order = Ix3(0, 1, 2);
806805
macro_rules! swap {
807806
($stride:expr, $order:expr, $x:expr, $y:expr) => {
808-
if $stride[$x] > $stride[$y] {
807+
if ($stride[$x] as isize).abs() > ($stride[$y] as isize).abs() {
809808
$stride.swap($x, $y);
810809
$order.ixm().swap($x, $y);
811810
}

src/dimension/mod.rs

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,12 @@ pub fn stride_offset(n: Ix, stride: Ix) -> isize {
4646
/// There is overlap if, when iterating through the dimensions in order of
4747
/// increasing stride, the current stride is less than or equal to the maximum
4848
/// possible offset along the preceding axes. (Axes of length ≤1 are ignored.)
49-
///
50-
/// The current implementation assumes that strides of axes with length > 1 are
51-
/// nonnegative. Additionally, it does not check for overflow.
5249
pub fn dim_stride_overlap<D: Dimension>(dim: &D, strides: &D) -> bool {
5350
let order = strides._fastest_varying_stride_order();
5451
let mut sum_prev_offsets = 0;
5552
for &index in order.slice() {
5653
let d = dim[index];
57-
let s = strides[index] as isize;
54+
let s = (strides[index] as isize).abs();
5855
match d {
5956
0 => return false,
6057
1 => {}
@@ -210,8 +207,7 @@ where
210207
///
211208
/// 2. The product of non-zero axis lengths must not exceed `isize::MAX`.
212209
///
213-
/// 3. For axes with length > 1, the stride must be nonnegative. This is
214-
/// necessary to make sure the pointer cannot move backwards outside the
210+
/// 3. For axes with length > 1, the pointer cannot move outside the
215211
/// slice. For axes with length ≤ 1, the stride can be anything.
216212
///
217213
/// 4. If the array will be empty (any axes are zero-length), the difference
@@ -257,14 +253,6 @@ fn can_index_slice_impl<D: Dimension>(
257253
return Err(from_kind(ErrorKind::OutOfBounds));
258254
}
259255

260-
// Check condition 3.
261-
for (&d, &s) in izip!(dim.slice(), strides.slice()) {
262-
let s = s as isize;
263-
if d > 1 && s < 0 {
264-
return Err(from_kind(ErrorKind::Unsupported));
265-
}
266-
}
267-
268256
// Check condition 5.
269257
if !is_empty && dim_stride_overlap(dim, strides) {
270258
return Err(from_kind(ErrorKind::Unsupported));
@@ -394,6 +382,19 @@ fn to_abs_slice(axis_len: usize, slice: Slice) -> (usize, usize, isize) {
394382
(start, end, step)
395383
}
396384

385+
/// Computes the offset from the logically first element to the first element in
386+
/// memory of the array, by summing `stride * (len - 1)` over every axis whose stride is negative
/// when read as `isize`; axes with non-negative strides contribute nothing.
///
/// NOTE(review): the result is <= 0 only if no negatively-strided axis has length 0; for such an
/// axis `len - 1` is -1 and its term is positive. Confirm that callers never pass that case.
387+
pub fn offset_from_ptr_to_memory(dim: &[Ix], strides: &[Ix]) -> isize {
388+
let offset = izip!(dim, strides).fold(0, |_offset, (d, s)| {
389+
if (*s as isize) < 0 {
390+
_offset + *s as isize * (*d as isize - 1)
391+
} else {
392+
_offset
393+
}
394+
});
395+
offset
396+
}
397+
397398
/// Modify dimension, stride and return data pointer offset
398399
///
399400
/// **Panics** if stride is 0 or if any index is out of bounds.
@@ -693,13 +694,21 @@ mod test {
693694
let dim = (2, 3, 2).into_dimension();
694695
let strides = (5, 2, 1).into_dimension();
695696
assert!(super::dim_stride_overlap(&dim, &strides));
697+
let strides = (-5isize as usize, 2, -1isize as usize).into_dimension();
698+
assert!(super::dim_stride_overlap(&dim, &strides));
696699
let strides = (6, 2, 1).into_dimension();
697700
assert!(!super::dim_stride_overlap(&dim, &strides));
701+
let strides = (6, -2isize as usize, 1).into_dimension();
702+
assert!(!super::dim_stride_overlap(&dim, &strides));
698703
let strides = (6, 0, 1).into_dimension();
699704
assert!(super::dim_stride_overlap(&dim, &strides));
705+
let strides = (-6isize as usize, 0, 1).into_dimension();
706+
assert!(super::dim_stride_overlap(&dim, &strides));
700707
let dim = (2, 2).into_dimension();
701708
let strides = (3, 2).into_dimension();
702709
assert!(!super::dim_stride_overlap(&dim, &strides));
710+
let strides = (3, -2isize as usize).into_dimension();
711+
assert!(!super::dim_stride_overlap(&dim, &strides));
703712
}
704713

705714
#[test]
@@ -736,7 +745,7 @@ mod test {
736745
can_index_slice::<i32, _>(&[1], &Ix1(2), &Ix1(1)).unwrap_err();
737746
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(0)).unwrap_err();
738747
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(1)).unwrap();
739-
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(-1isize as usize)).unwrap_err();
748+
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(-1isize as usize)).unwrap();
740749
}
741750

742751
#[test]

src/impl_constructors.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use num_traits::{Float, One, Zero};
1616
use std::mem::MaybeUninit;
1717

1818
use crate::dimension;
19+
use crate::dimension::offset_from_ptr_to_memory;
1920
use crate::error::{self, ShapeError};
2021
use crate::extension::nonnull::nonnull_from_vec_data;
2122
use crate::imp_prelude::*;
@@ -24,6 +25,7 @@ use crate::indices;
2425
use crate::iterators::{to_vec, to_vec_mapped};
2526
use crate::StrideShape;
2627
use crate::{geomspace, linspace, logspace};
28+
use rawpointer::PointerExt;
2729

2830
/// # Constructor Methods for Owned Arrays
2931
///
@@ -431,7 +433,8 @@ where
431433
///
432434
/// 2. The product of non-zero axis lengths must not exceed `isize::MAX`.
433435
///
434-
/// 3. For axes with length > 1, the stride must be nonnegative.
436+
/// 3. For axes with length > 1, the pointer cannot move outside the
437+
/// slice.
435438
///
436439
/// 4. If the array will be empty (any axes are zero-length), the
437440
/// difference between the least address and greatest address accessible
@@ -457,7 +460,8 @@ where
457460
// debug check for issues that indicates wrong use of this constructor
458461
debug_assert!(dimension::can_index_slice(&v, &dim, &strides).is_ok());
459462
ArrayBase {
460-
ptr: nonnull_from_vec_data(&mut v),
463+
ptr: nonnull_from_vec_data(&mut v)
464+
.offset(offset_from_ptr_to_memory(dim.slice(), strides.slice()).abs()),
461465
data: DataOwned::new(v),
462466
strides,
463467
dim,
@@ -483,7 +487,7 @@ where
483487
///
484488
/// This constructor is limited to elements where `A: Copy` (no destructors)
485489
/// to avoid users shooting themselves too hard in the foot.
486-
///
490+
///
487491
/// (Also note that the constructors `from_shape_vec` and
488492
/// `from_shape_vec_unchecked` allow the user yet more control, in the sense
489493
/// that Arrays can be created from arbitrary vectors.)

src/impl_methods.rs

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ use crate::arraytraits;
1717
use crate::dimension;
1818
use crate::dimension::IntoDimension;
1919
use crate::dimension::{
20-
abs_index, axes_of, do_slice, merge_axes, size_of_shape_checked, stride_offset, Axes,
20+
abs_index, axes_of, do_slice, merge_axes, offset_from_ptr_to_memory, size_of_shape_checked,
21+
stride_offset, Axes,
2122
};
2223
use crate::error::{self, ErrorKind, ShapeError};
2324
use crate::itertools::zip;
@@ -1280,9 +1281,6 @@ where
12801281
}
12811282

12821283
/// Return true if the array is known to be contiguous.
1283-
///
1284-
/// Will detect c- and f-contig arrays correctly, but otherwise
1285-
/// There are some false negatives.
12861284
pub(crate) fn is_contiguous(&self) -> bool {
12871285
D::is_contiguous(&self.dim, &self.strides)
12881286
}
@@ -1404,14 +1402,18 @@ where
14041402
///
14051403
/// If this function returns `Some(_)`, then the elements in the slice
14061404
/// have whatever order the elements have in memory.
1407-
///
1408-
/// Implementation notes: Does not yet support negatively strided arrays.
14091405
pub fn as_slice_memory_order(&self) -> Option<&[A]>
14101406
where
14111407
S: Data,
14121408
{
14131409
if self.is_contiguous() {
1414-
unsafe { Some(slice::from_raw_parts(self.ptr.as_ptr(), self.len())) }
1410+
let offset = offset_from_ptr_to_memory(self.dim.slice(), self.strides.slice());
1411+
unsafe {
1412+
Some(slice::from_raw_parts(
1413+
self.ptr.offset(offset).as_ptr(),
1414+
self.len(),
1415+
))
1416+
}
14151417
} else {
14161418
None
14171419
}
@@ -1425,7 +1427,13 @@ where
14251427
{
14261428
if self.is_contiguous() {
14271429
self.ensure_unique();
1428-
unsafe { Some(slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len())) }
1430+
let offset = offset_from_ptr_to_memory(self.dim.slice(), self.strides.slice());
1431+
unsafe {
1432+
Some(slice::from_raw_parts_mut(
1433+
self.ptr.offset(offset).as_ptr(),
1434+
self.len(),
1435+
))
1436+
}
14291437
} else {
14301438
None
14311439
}

tests/dimension.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,19 +118,36 @@ fn fastest_varying_order() {
118118
let order = strides._fastest_varying_stride_order();
119119
assert_eq!(order.slice(), &[3, 0, 2, 1]);
120120

121+
let strides = Dim([-2isize as usize, 8, -4isize as usize, -1isize as usize]);
122+
let order = strides._fastest_varying_stride_order();
123+
assert_eq!(order.slice(), &[3, 0, 2, 1]);
124+
121125
assert_eq!(Dim([1, 3])._fastest_varying_stride_order(), Dim([0, 1]));
126+
assert_eq!(
127+
Dim([1, -3isize as usize])._fastest_varying_stride_order(),
128+
Dim([0, 1])
129+
);
122130
assert_eq!(Dim([7, 2])._fastest_varying_stride_order(), Dim([1, 0]));
131+
assert_eq!(
132+
Dim([-7isize as usize, 2])._fastest_varying_stride_order(),
133+
Dim([1, 0])
134+
);
123135
assert_eq!(
124136
Dim([6, 1, 3])._fastest_varying_stride_order(),
125137
Dim([1, 2, 0])
126138
);
139+
assert_eq!(
140+
Dim([-6isize as usize, 1, -3isize as usize])._fastest_varying_stride_order(),
141+
Dim([1, 2, 0])
142+
);
127143

128144
// it's important that it produces distinct indices. Prefer the stable order
129145
// where 0 is before 1 when they are equal.
130146
assert_eq!(Dim([2, 2])._fastest_varying_stride_order(), [0, 1]);
131147
assert_eq!(Dim([2, 2, 1])._fastest_varying_stride_order(), [2, 0, 1]);
132148
assert_eq!(
133-
Dim([2, 2, 3, 1, 2])._fastest_varying_stride_order(),
149+
Dim([-2isize as usize, -2isize as usize, 3, 1, -2isize as usize])
150+
._fastest_varying_stride_order(),
134151
[3, 0, 1, 4, 2]
135152
);
136153
}

tests/oper.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,7 @@ fn gen_mat_vec_mul() {
707707
S2: Data<Elem = A>,
708708
{
709709
let ((m, _), k) = (lhs.dim(), rhs.dim());
710-
reference_mat_mul(lhs, &rhs.to_owned().into_shape((k, 1)).unwrap())
710+
reference_mat_mul(lhs, &rhs.as_standard_layout().into_shape((k, 1)).unwrap())
711711
.into_shape(m)
712712
.unwrap()
713713
}
@@ -772,7 +772,7 @@ fn vec_mat_mul() {
772772
S2: Data<Elem = A>,
773773
{
774774
let (m, (_, n)) = (lhs.dim(), rhs.dim());
775-
reference_mat_mul(&lhs.to_owned().into_shape((1, m)).unwrap(), rhs)
775+
reference_mat_mul(&lhs.as_standard_layout().into_shape((1, m)).unwrap(), rhs)
776776
.into_shape(n)
777777
.unwrap()
778778
}

0 commit comments

Comments
 (0)