Skip to content

Commit efc5f04

Browse files
committed
fix negative strides
1 parent 5056c49 commit efc5f04

File tree

3 files changed

+32
-7
lines changed

3 files changed

+32
-7
lines changed

dpnp/backend/extensions/blas/dot.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,15 @@ std::pair<sycl::event, sycl::event> dot(sycl::queue exec_q,
193193
char *b_typeless_ptr = vectorB.get_data();
194194
char *r_typeless_ptr = result.get_data();
195195

196+
const int a_elemsize = vectorA.get_elemsize();
197+
const int b_elemsize = vectorB.get_elemsize();
198+
if (str_a < 0) {
199+
a_typeless_ptr -= (n - 1) * std::abs(str_a) * a_elemsize;
200+
}
201+
if (str_b < 0) {
202+
b_typeless_ptr -= (n - 1) * std::abs(str_b) * b_elemsize;
203+
}
204+
196205
sycl::event dot_ev = dot_fn(exec_q, n, a_typeless_ptr, str_a,
197206
b_typeless_ptr, str_b, r_typeless_ptr, depends);
198207

dpnp/backend/extensions/blas/dotu.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,15 @@ std::pair<sycl::event, sycl::event>
195195
char *b_typeless_ptr = vectorB.get_data();
196196
char *r_typeless_ptr = result.get_data();
197197

198+
const int a_elemsize = vectorA.get_elemsize();
199+
const int b_elemsize = vectorB.get_elemsize();
200+
if (str_a < 0) {
201+
a_typeless_ptr -= (n - 1) * std::abs(str_a) * a_elemsize;
202+
}
203+
if (str_b < 0) {
204+
b_typeless_ptr -= (n - 1) * std::abs(str_b) * b_elemsize;
205+
}
206+
198207
sycl::event dotu_ev =
199208
dotu_fn(exec_q, n, a_typeless_ptr, str_a, b_typeless_ptr, str_b,
200209
r_typeless_ptr, depends);

tests/test_dot.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -181,21 +181,28 @@ def test_dot_ndarray(dtype, array_info):
181181
assert_dtype_allclose(result, expected)
182182

183183

184-
@pytest.mark.parametrize("dtype", get_all_dtypes())
184+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
185185
def test_dot_strided(dtype):
186-
a = numpy.ones(25, dtype=dtype)
187-
b = numpy.ones(25, dtype=dtype)
186+
a = numpy.arange(25, dtype=dtype)
187+
b = numpy.arange(25, dtype=dtype)
188188
ia = dpnp.array(a)
189189
ib = dpnp.array(b)
190190

191191
result = dpnp.dot(ia[::3], ib[::3])
192192
expected = numpy.dot(a[::3], b[::3])
193193
assert_dtype_allclose(result, expected)
194194

195-
# TODO: unmute when the problem with negative stride is fixed
196-
# result = dpnp.dot(ia, ib[::-1])
197-
# expected = numpy.dot(a, b[::-1])
198-
# assert_dtype_allclose(result, expected)
195+
result = dpnp.dot(ia, ib[::-1])
196+
expected = numpy.dot(a, b[::-1])
197+
assert_dtype_allclose(result, expected)
198+
199+
result = dpnp.dot(ia[::-2], ib[::-2])
200+
expected = numpy.dot(a[::-2], b[::-2])
201+
assert_dtype_allclose(result, expected)
202+
203+
result = dpnp.dot(ia[::-5], ib[::-5])
204+
expected = numpy.dot(a[::-5], b[::-5])
205+
assert_dtype_allclose(result, expected)
199206

200207

201208
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))

0 commit comments

Comments
 (0)