Skip to content

Commit aefbb9a

Browse files
Support shift tuple when axis=None for roll
1 parent 50f1074 commit aefbb9a

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,9 @@ def roll(X, shift, axis=None):
426426
if not isinstance(X, dpt.usm_ndarray):
427427
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
428428
if axis is None:
429+
# get the combined shift value for all axes
430+
if type(shift) is tuple:
431+
shift = sum(shift)
429432
res = dpt.empty(
430433
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=X.sycl_queue
431434
)

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,8 @@ def test_roll_empty():
590590
"data",
591591
[
592592
[2, None],
593+
[(0, 1), None],
594+
[(-1, 0), None],
593595
[-2, None],
594596
[2, 0],
595597
[-2, 0],
@@ -617,6 +619,8 @@ def test_roll_1d(data):
617619
"data",
618620
[
619621
[1, None],
622+
[(2, 1), None],
623+
[(-1, 2), None],
620624
[1, 0],
621625
[1, 1],
622626
[1, ()],

0 commit comments

Comments
 (0)