Skip to content

Commit 67cab69

Browse files
Deploy _copy_usm_ndarray_for_roll
Remove use of `shift=0` argument to `_copy_usm_ndarray_for_reshape` in _reshape.py Used `_copy_usm_ndarray_for_roll` in `roll` implementation.
1 parent 467d133 commit 67cab69

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ def roll(X, shift, axis=None):
429429
res = dpt.empty(
430430
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=X.sycl_queue
431431
)
432-
hev, _ = ti._copy_usm_ndarray_for_reshape(
432+
hev, _ = ti._copy_usm_ndarray_for_roll(
433433
src=X, dst=res, shift=shift, sycl_queue=X.sycl_queue
434434
)
435435
hev.wait()
@@ -550,7 +550,6 @@ def _concat_axis_None(arrays):
550550
hev, _ = ti._copy_usm_ndarray_for_reshape(
551551
src=src_,
552552
dst=res[fill_start:fill_end],
553-
shift=0,
554553
sycl_queue=exec_q,
555554
)
556555
fill_start = fill_end

dpctl/tensor/_reshape.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def reshape(X, shape, order="C", copy=None):
165165
)
166166
if order == "C":
167167
hev, _ = _copy_usm_ndarray_for_reshape(
168-
src=X, dst=flat_res, shift=0, sycl_queue=X.sycl_queue
168+
src=X, dst=flat_res, sycl_queue=X.sycl_queue
169169
)
170170
hev.wait()
171171
else:

0 commit comments

Comments
 (0)