Skip to content

Commit e5785ca

Browse files
Fixed bug in concat uncovered by array API tests
``` import dpctl.tensor as dpt x1 = dpt.full(tuple(), 77, dtype='u2') x2 = dpt.zeros(2, dtype='uint8')[dpt.newaxis, :] dpt.concat((x1, x2), axis=None) ``` The reason the exception was raised is that _copy_usm_ndarray_for_reshape which is used in the implementation of concat with axis=None requires both source and destination to have the same data-type.
1 parent 51e3f15 commit e5785ca

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,8 +554,13 @@ def _concat_axis_None(arrays):
554554
sycl_queue=exec_q,
555555
)
556556
else:
557+
src_ = array
558+
# _copy_usm_ndarray_for_reshape requires src and dst to have
559+
# the same data type
560+
if not array.dtype == res_dtype:
561+
src_ = dpt.astype(src_, res_dtype)
557562
hev, _ = ti._copy_usm_ndarray_for_reshape(
558-
src=array,
563+
src=src_,
559564
dst=res[fill_start:fill_end],
560565
shift=0,
561566
sycl_queue=exec_q,

0 commit comments

Comments
 (0)