Skip to content

Commit 92c2fae

Browse files
densmirnoleksandr-pavlyk
authored andcommitted
Normilize queue and device + change copying from numpy in call_origin
1 parent 9fb287c commit 92c2fae

File tree

2 files changed

+25
-7
lines changed

2 files changed

+25
-7
lines changed

dpnp/dpnp_container.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import dpctl.tensor as dpt
4444
from dpctl.tensor._device import normalize_queue_device
4545

46+
4647
if config.__DPNP_OUTPUT_DPCTL__:
4748
try:
4849
"""

dpnp/dpnp_utils/dpnp_algo_utils.pyx

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,27 @@ def convert_list_args(input_list):
9797

9898
return result_list
9999

100+
def copy_from_origin(src, device=None, usm_type="device", sycl_queue=None):
101+
"""Copy result from origin."""
102+
if not isinstance(src, numpy.ndarray):
103+
return src
104+
105+
if src.size == 0:
106+
return dpnp_container.empty(src.shape,
107+
dtype=src.dtype,
108+
device=device,
109+
usm_type=usm_type,
110+
sycl_queue=sycl_queue)
111+
112+
array_obj = dpctl.tensor._copy_utils.from_numpy(src,
113+
device=device,
114+
usm_type=usm_type,
115+
sycl_queue=sycl_queue)
100116

101-
def copy_from_origin(dst, src):
117+
return dpnp.dpnp_array.dpnp_array(src.shape, buffer=array_obj)
118+
119+
120+
def copy_from_origin_into(dst, src):
102121
"""Copy origin result to output result."""
103122
if config.__DPNP_OUTPUT_DPCTL__ and hasattr(dst, "__sycl_usm_array_interface__"):
104123
if src.size:
@@ -145,7 +164,7 @@ def call_origin(function, *args, **kwargs):
145164
if args and args_new:
146165
arg, arg_new = args[0], args_new[0]
147166
if isinstance(arg_new, numpy.ndarray):
148-
copy_from_origin(arg, arg_new)
167+
copy_from_origin_into(arg, arg_new)
149168
elif isinstance(arg_new, list):
150169
for i, val in enumerate(arg_new):
151170
arg[i] = val
@@ -157,20 +176,18 @@ def call_origin(function, *args, **kwargs):
157176
if (kwargs_dtype is not None):
158177
result_dtype = kwargs_dtype
159178

160-
result = dpnp_container.empty(result_origin.shape, dtype=result_dtype)
179+
result = copy_from_origin(result_origin)
161180
else:
162181
result = kwargs_out
163-
164-
copy_from_origin(result, result_origin)
182+
copy_from_origin_into(result, result_origin)
165183

166184
elif isinstance(result, tuple):
167185
# convert tuple(fallback_array) to tuple(result_array)
168186
result_list = []
169187
for res_origin in result:
170188
res = res_origin
171189
if isinstance(res_origin, numpy.ndarray):
172-
res = dpnp_container.empty(res_origin.shape, dtype=res_origin.dtype)
173-
copy_from_origin(res, res_origin)
190+
res = copy_from_origin(res_origin)
174191
result_list.append(res)
175192

176193
result = tuple(result_list)

0 commit comments

Comments
 (0)