Skip to content

Commit 37c63fd

Browse files
committed
Explicit cast shift to numpy array in dpnp.roll function
1 parent 681360e commit 37c63fd

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3207,10 +3207,14 @@ def roll(x, shift, axis=None):
32073207
[3, 4, 0, 1, 2]])
32083208
32093209
"""
3210-
if axis is None:
3211-
return roll(x.reshape(-1), shift, 0).reshape(x.shape)
32123210

32133211
usm_x = dpnp.get_usm_ndarray(x)
3212+
if dpnp.is_supported_array_type(shift):
3213+
shift = dpnp.asnumpy(shift)
3214+
3215+
if axis is None:
3216+
return roll(dpt.reshape(usm_x, -1), shift, 0).reshape(x.shape)
3217+
32143218
usm_res = dpt.roll(usm_x, shift=shift, axis=axis)
32153219
return dpnp_array._create_from_usm_ndarray(usm_res)
32163220

0 commit comments

Comments
 (0)