Skip to content

Commit fdf9ada

Browse files
committed
Fixes diff for Python scalar append or prepend
Typos caused this to result in incorrect shaped outputs
1 parent 316254b commit fdf9ada

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

dpctl/tensor/_utility_functions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -287,10 +287,10 @@ def _concat_diff_input(arr, axis, prepend, append):
287287
)
288288
if not prepend_shape:
289289
prepend_shape = arr_shape[:axis] + (1,) + arr_shape[axis + 1 :]
290-
a_prepend = dpt.broadcast_to(a_prepend, arr_shape)
290+
a_prepend = dpt.broadcast_to(a_prepend, prepend_shape)
291291
if not append_shape:
292292
append_shape = arr_shape[:axis] + (1,) + arr_shape[axis + 1 :]
293-
a_append = dpt.broadcast_to(a_append, arr_shape)
293+
a_append = dpt.broadcast_to(a_append, append_shape)
294294
return dpt.concat((a_prepend, arr, a_append), axis=axis)
295295
elif prepend is not None:
296296
q1, x_usm_type = arr.sycl_queue, arr.usm_type
@@ -347,7 +347,7 @@ def _concat_diff_input(arr, axis, prepend, append):
347347
)
348348
if not prepend_shape:
349349
prepend_shape = arr_shape[:axis] + (1,) + arr_shape[axis + 1 :]
350-
a_prepend = dpt.broadcast_to(a_prepend, arr_shape)
350+
a_prepend = dpt.broadcast_to(a_prepend, prepend_shape)
351351
return dpt.concat((a_prepend, arr), axis=axis)
352352
elif append is not None:
353353
q1, x_usm_type = arr.sycl_queue, arr.usm_type
@@ -402,7 +402,7 @@ def _concat_diff_input(arr, axis, prepend, append):
402402
)
403403
if not append_shape:
404404
append_shape = arr_shape[:axis] + (1,) + arr_shape[axis + 1 :]
405-
a_append = dpt.broadcast_to(a_append, arr_shape)
405+
a_append = dpt.broadcast_to(a_append, append_shape)
406406
return dpt.concat((arr, a_append), axis=axis)
407407
else:
408408
arr1 = arr

0 commit comments

Comments
 (0)