@@ -287,10 +287,10 @@ def _concat_diff_input(arr, axis, prepend, append):
287
287
)
288
288
if not prepend_shape :
289
289
prepend_shape = arr_shape [:axis ] + (1 ,) + arr_shape [axis + 1 :]
290
- a_prepend = dpt .broadcast_to (a_prepend , arr_shape )
290
+ a_prepend = dpt .broadcast_to (a_prepend , prepend_shape )
291
291
if not append_shape :
292
292
append_shape = arr_shape [:axis ] + (1 ,) + arr_shape [axis + 1 :]
293
- a_append = dpt .broadcast_to (a_append , arr_shape )
293
+ a_append = dpt .broadcast_to (a_append , append_shape )
294
294
return dpt .concat ((a_prepend , arr , a_append ), axis = axis )
295
295
elif prepend is not None :
296
296
q1 , x_usm_type = arr .sycl_queue , arr .usm_type
@@ -347,7 +347,7 @@ def _concat_diff_input(arr, axis, prepend, append):
347
347
)
348
348
if not prepend_shape :
349
349
prepend_shape = arr_shape [:axis ] + (1 ,) + arr_shape [axis + 1 :]
350
- a_prepend = dpt .broadcast_to (a_prepend , arr_shape )
350
+ a_prepend = dpt .broadcast_to (a_prepend , prepend_shape )
351
351
return dpt .concat ((a_prepend , arr ), axis = axis )
352
352
elif append is not None :
353
353
q1 , x_usm_type = arr .sycl_queue , arr .usm_type
@@ -402,7 +402,7 @@ def _concat_diff_input(arr, axis, prepend, append):
402
402
)
403
403
if not append_shape :
404
404
append_shape = arr_shape [:axis ] + (1 ,) + arr_shape [axis + 1 :]
405
- a_append = dpt .broadcast_to (a_append , arr_shape )
405
+ a_append = dpt .broadcast_to (a_append , append_shape )
406
406
return dpt .concat ((arr , a_append ), axis = axis )
407
407
else :
408
408
arr1 = arr
0 commit comments