@@ -118,23 +118,31 @@ def call_origin(function, *args, **kwargs):
118
118
# print(f"DPNP call_origin(): Fallback called. \n\t function={function}, \n\t args={args}, \n\t kwargs={kwargs}, \n\t dpnp_inplace={dpnp_inplace}")
119
119
120
120
kwargs_out = kwargs.get(" out" , None )
121
+ alloc_queues = []
121
122
if (kwargs_out is not None ):
122
123
if isinstance (kwargs_out, numpy.ndarray):
123
124
kwargs[" out" ] = kwargs_out
124
125
else :
126
+ if hasattr (kwargs_out, " sycl_queue" ):
127
+ alloc_queues.append(kwargs_out.sycl_queue)
125
128
kwargs[" out" ] = dpnp.asnumpy(kwargs_out)
126
129
127
130
args_new_list = []
128
131
for arg in args:
132
+ if hasattr (arg, " sycl_queue" ):
133
+ alloc_queues.append(arg.sycl_queue)
129
134
argx = convert_item(arg)
130
135
args_new_list.append(argx)
131
136
args_new = tuple (args_new_list)
132
137
133
138
kwargs_new = {}
134
139
for key, kwarg in kwargs.items():
140
+ if hasattr (kwarg, " sycl_queue" ):
141
+ alloc_queues.append(kwarg.sycl_queue)
135
142
kwargx = convert_item(kwarg)
136
143
kwargs_new[key] = kwargx
137
144
145
+ exec_q = dpctl.utils.get_execution_queue(alloc_queues)
138
146
# print(f"DPNP call_origin(): bakend called. \n\t function={function}, \n\t args_new={args_new}, \n\t kwargs_new={kwargs_new}, \n\t dpnp_inplace={dpnp_inplace}")
139
147
# TODO need to put array memory into NumPy call
140
148
result_origin = function(* args_new, ** kwargs_new)
@@ -157,7 +165,7 @@ def call_origin(function, *args, **kwargs):
157
165
if (kwargs_dtype is not None ):
158
166
result_dtype = kwargs_dtype
159
167
160
- result = dpnp_container.empty(result_origin.shape, dtype = result_dtype)
168
+ result = dpnp_container.empty(result_origin.shape, dtype = result_dtype, sycl_queue = exec_q )
161
169
else :
162
170
result = kwargs_out
163
171
@@ -169,7 +177,7 @@ def call_origin(function, *args, **kwargs):
169
177
for res_origin in result:
170
178
res = res_origin
171
179
if isinstance (res_origin, numpy.ndarray):
172
- res = dpnp_container.empty(res_origin.shape, dtype = res_origin.dtype)
180
+ res = dpnp_container.empty(res_origin.shape, dtype = res_origin.dtype, sycl_queue = exec_q )
173
181
copy_from_origin(res, res_origin)
174
182
result_list.append(res)
175
183
0 commit comments