@@ -97,8 +97,27 @@ def convert_list_args(input_list):
97
97
98
98
return result_list
99
99
100
+ def copy_from_origin (src , device = None , usm_type = " device" , sycl_queue = None ):
101
+ """ Copy result from origin."""
102
+ if not isinstance (src, numpy.ndarray):
103
+ return src
104
+
105
+ if src.size == 0 :
106
+ return dpnp_container.empty(src.shape,
107
+ dtype = src.dtype,
108
+ device = device,
109
+ usm_type = usm_type,
110
+ sycl_queue = sycl_queue)
111
+
112
+ array_obj = dpctl.tensor._copy_utils.from_numpy(src,
113
+ device = device,
114
+ usm_type = usm_type,
115
+ sycl_queue = sycl_queue)
100
116
101
- def copy_from_origin (dst , src ):
117
+ return dpnp.dpnp_array.dpnp_array(src.shape, buffer = array_obj)
118
+
119
+
120
+ def copy_from_origin_into (dst , src ):
102
121
""" Copy origin result to output result."""
103
122
if config.__DPNP_OUTPUT_DPCTL__ and hasattr (dst, " __sycl_usm_array_interface__" ):
104
123
if src.size:
@@ -145,7 +164,7 @@ def call_origin(function, *args, **kwargs):
145
164
if args and args_new:
146
165
arg, arg_new = args[0 ], args_new[0 ]
147
166
if isinstance (arg_new, numpy.ndarray):
148
- copy_from_origin (arg, arg_new)
167
+ copy_from_origin_into (arg, arg_new)
149
168
elif isinstance (arg_new, list ):
150
169
for i, val in enumerate (arg_new):
151
170
arg[i] = val
@@ -157,20 +176,18 @@ def call_origin(function, *args, **kwargs):
157
176
if (kwargs_dtype is not None ):
158
177
result_dtype = kwargs_dtype
159
178
160
- result = dpnp_container.empty (result_origin.shape, dtype = result_dtype )
179
+ result = copy_from_origin (result_origin)
161
180
else :
162
181
result = kwargs_out
163
-
164
- copy_from_origin(result, result_origin)
182
+ copy_from_origin_into(result, result_origin)
165
183
166
184
elif isinstance (result, tuple ):
167
185
# convert tuple(fallback_array) to tuple(result_array)
168
186
result_list = []
169
187
for res_origin in result:
170
188
res = res_origin
171
189
if isinstance (res_origin, numpy.ndarray):
172
- res = dpnp_container.empty(res_origin.shape, dtype = res_origin.dtype)
173
- copy_from_origin(res, res_origin)
190
+ res = copy_from_origin(res_origin)
174
191
result_list.append(res)
175
192
176
193
result = tuple (result_list)
0 commit comments