@@ -193,6 +193,28 @@ def _asarray_from_usm_ndarray(
193
193
return res
194
194
195
195
196
+ def _map_to_device_dtype (dt , q ):
197
+ if dt .char == "?" or np .issubdtype (dt , np .integer ):
198
+ return dt
199
+ d = q .sycl_device
200
+ dtc = dt .char
201
+ if np .issubdtype (dt , np .floating ):
202
+ if dtc == "f" :
203
+ return dt
204
+ else :
205
+ if dtc == "d" and d .has_aspect_fp64 :
206
+ return dt
207
+ if dtc == "h" and d .has_aspect_fp16 :
208
+ return dt
209
+ return dpt .dtype ("f4" )
210
+ elif np .issubdtype (dt , np .complexfloating ):
211
+ if dtc == "F" :
212
+ return dt
213
+ if dtc == "D" and d .has_aspect_fp64 :
214
+ return dt
215
+ return dpt .dtype ("c8" )
216
+
217
+
196
218
def _asarray_from_numpy_ndarray (
197
219
ary , dtype = None , usm_type = None , sycl_queue = None , order = "K"
198
220
):
@@ -207,10 +229,8 @@ def _asarray_from_numpy_ndarray(
207
229
"Please convert the input to an array with numeric data type."
208
230
)
209
231
if dtype is None :
210
- ary_dtype = ary .dtype
211
- dtype = _get_dtype (dtype , copy_q , ref_type = ary_dtype )
212
- if dtype .itemsize > ary_dtype .itemsize or ary_dtype == np .uint64 :
213
- dtype = ary_dtype
232
+ # deduce device-representable output data type
233
+ dtype = _map_to_device_dtype (ary .dtype , copy_q )
214
234
f_contig = ary .flags ["F" ]
215
235
c_contig = ary .flags ["C" ]
216
236
fc_contig = f_contig or c_contig
@@ -246,7 +266,7 @@ def _asarray_from_numpy_ndarray(
246
266
order = order ,
247
267
buffer_ctor_kwargs = {"queue" : copy_q },
248
268
)
249
- ti . _copy_numpy_ndarray_into_usm_ndarray ( src = ary , dst = res , sycl_queue = copy_q )
269
+ res [...] = ary
250
270
return res
251
271
252
272
0 commit comments