Skip to content

Commit eb9b652

Browse files
asarray to use mapping from ndarray data type to usm_ndarray type
This mapping taking into account device aspects.
1 parent 04d194c commit eb9b652

File tree

1 file changed

+25
-5
lines changed

1 file changed

+25
-5
lines changed

dpctl/tensor/_ctors.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,28 @@ def _asarray_from_usm_ndarray(
193193
return res
194194

195195

196+
def _map_to_device_dtype(dt, q):
197+
if dt.char == "?" or np.issubdtype(dt, np.integer):
198+
return dt
199+
d = q.sycl_device
200+
dtc = dt.char
201+
if np.issubdtype(dt, np.floating):
202+
if dtc == "f":
203+
return dt
204+
else:
205+
if dtc == "d" and d.has_aspect_fp64:
206+
return dt
207+
if dtc == "h" and d.has_aspect_fp16:
208+
return dt
209+
return dpt.dtype("f4")
210+
elif np.issubdtype(dt, np.complexfloating):
211+
if dtc == "F":
212+
return dt
213+
if dtc == "D" and d.has_aspect_fp64:
214+
return dt
215+
return dpt.dtype("c8")
216+
217+
196218
def _asarray_from_numpy_ndarray(
197219
ary, dtype=None, usm_type=None, sycl_queue=None, order="K"
198220
):
@@ -207,10 +229,8 @@ def _asarray_from_numpy_ndarray(
207229
"Please convert the input to an array with numeric data type."
208230
)
209231
if dtype is None:
210-
ary_dtype = ary.dtype
211-
dtype = _get_dtype(dtype, copy_q, ref_type=ary_dtype)
212-
if dtype.itemsize > ary_dtype.itemsize or ary_dtype == np.uint64:
213-
dtype = ary_dtype
232+
# deduce device-representable output data type
233+
dtype = _map_to_device_dtype(ary.dtype, copy_q)
214234
f_contig = ary.flags["F"]
215235
c_contig = ary.flags["C"]
216236
fc_contig = f_contig or c_contig
@@ -246,7 +266,7 @@ def _asarray_from_numpy_ndarray(
246266
order=order,
247267
buffer_ctor_kwargs={"queue": copy_q},
248268
)
249-
ti._copy_numpy_ndarray_into_usm_ndarray(src=ary, dst=res, sycl_queue=copy_q)
269+
res[...] = ary
250270
return res
251271

252272

0 commit comments

Comments
 (0)