Skip to content

Commit 6054b1f

Browse files
Fix division by zero exception found by array API tests suite
1 parent a3d3509 commit 6054b1f

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,10 +353,12 @@ def roll(X, /, shift, *, axis=None):
353353
res = dpt.empty(
354354
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=exec_q
355355
)
356+
sz = X.size
357+
shift = (shift % sz) if sz > 0 else 0
356358
hev, roll_ev = ti._copy_usm_ndarray_for_roll_1d(
357359
src=X,
358360
dst=res,
359-
shift=(shift % X.size),
361+
shift=shift,
360362
sycl_queue=exec_q,
361363
depends=dep_evs,
362364
)

0 commit comments

Comments
 (0)