Skip to content

Commit d58cba5

Browse files
Roll must reduce shift steps by size along axis
Closes gh-1857 Test added based on example provided in the issue.
1 parent 2f327af commit d58cba5

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ def roll(X, /, shift, *, axis=None):
356356
hev, roll_ev = ti._copy_usm_ndarray_for_roll_1d(
357357
src=X,
358358
dst=res,
359-
shift=shift,
359+
shift=(shift % X.size),
360360
sycl_queue=exec_q,
361361
depends=dep_evs,
362362
)
@@ -369,9 +369,11 @@ def roll(X, /, shift, *, axis=None):
369369
shifts = [
370370
0,
371371
] * X.ndim
372+
shape = X.shape
372373
for sh, ax in broadcasted:
373-
shifts[ax] += sh
374-
374+
n_i = shape[ax]
375+
if n_i > 0:
376+
shifts[ax] = int(shifts[ax] + sh) % int(n_i)
375377
res = dpt.empty(
376378
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=exec_q
377379
)

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,16 @@ def test_roll_2d(data):
657657
assert_array_equal(Ynp, dpt.asnumpy(Y))
658658

659659

660+
def test_roll_out_bounds_shifts():
661+
"See gh-1857"
662+
get_queue_or_skip()
663+
664+
x = dpt.arange(4)
665+
y = dpt.roll(x, np.uint64(2**63 + 2))
666+
expected = dpt.roll(x, 2)
667+
assert dpt.all(y == expected)
668+
669+
660670
def test_roll_validation():
661671
get_queue_or_skip()
662672

0 commit comments

Comments
 (0)