Skip to content

Commit a3c00bc

Browse files
Closes gh-1325
``` In [1]: import dpctl.tensor as dpt In [2]: a = dpt.arange(10, dtype='int64') ...: b = dpt.arange(10, dtype='float32') In [3]: dpt.concat((a,b)) Out[3]: usm_ndarray([0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32) In [4]: _.sycl_device.name Out[4]: 'Intel(R) Graphics [0x9a49]' ```
1 parent e5785ca commit a3c00bc

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import dpctl.tensor._tensor_impl as ti
2626
import dpctl.utils as dputils
2727

28+
from ._type_utils import _to_device_supported_dtype
29+
2830
__doc__ = (
2931
"Implementation module for array manipulation "
3032
"functions in :module:`dpctl.tensor`"
@@ -504,8 +506,10 @@ def _arrays_validation(arrays, check_ndim=True):
504506
_supported_dtype(Xi.dtype for Xi in arrays)
505507

506508
res_dtype = X0.dtype
509+
dev = exec_q.sycl_device
507510
for i in range(1, n):
508511
res_dtype = np.promote_types(res_dtype, arrays[i])
512+
res_dtype = _to_device_supported_dtype(res_dtype, dev)
509513

510514
if check_ndim:
511515
for i in range(1, n):

0 commit comments

Comments
 (0)