Skip to content

Commit 404b587

Browse files
Modified call_origin to deduce result allocation queue from input allocation queues using CFD
1 parent c3a5cab commit 404b587

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

dpnp/dpnp_utils/dpnp_algo_utils.pyx

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,23 +118,31 @@ def call_origin(function, *args, **kwargs):
118118
# print(f"DPNP call_origin(): Fallback called. \n\t function={function}, \n\t args={args}, \n\t kwargs={kwargs}, \n\t dpnp_inplace={dpnp_inplace}")
119119

120120
kwargs_out = kwargs.get("out", None)
121+
alloc_queues = []
121122
if (kwargs_out is not None):
122123
if isinstance(kwargs_out, numpy.ndarray):
123124
kwargs["out"] = kwargs_out
124125
else:
126+
if hasattr(kwargs_out, "sycl_queue"):
127+
alloc_queues.append(kwargs_out.sycl_queue)
125128
kwargs["out"] = dpnp.asnumpy(kwargs_out)
126129

127130
args_new_list = []
128131
for arg in args:
132+
if hasattr(arg, "sycl_queue"):
133+
alloc_queues.append(arg.sycl_queue)
129134
argx = convert_item(arg)
130135
args_new_list.append(argx)
131136
args_new = tuple(args_new_list)
132137

133138
kwargs_new = {}
134139
for key, kwarg in kwargs.items():
140+
if hasattr(kwarg, "sycl_queue"):
141+
alloc_queues.append(kwarg.sycl_queue)
135142
kwargx = convert_item(kwarg)
136143
kwargs_new[key] = kwargx
137144

145+
exec_q = dpctl.utils.get_execution_queue(alloc_queues)
138146
# print(f"DPNP call_origin(): bakend called. \n\t function={function}, \n\t args_new={args_new}, \n\t kwargs_new={kwargs_new}, \n\t dpnp_inplace={dpnp_inplace}")
139147
# TODO need to put array memory into NumPy call
140148
result_origin = function(*args_new, **kwargs_new)
@@ -157,7 +165,7 @@ def call_origin(function, *args, **kwargs):
157165
if (kwargs_dtype is not None):
158166
result_dtype = kwargs_dtype
159167

160-
result = dpnp_container.empty(result_origin.shape, dtype=result_dtype)
168+
result = dpnp_container.empty(result_origin.shape, dtype=result_dtype, sycl_queue=exec_q)
161169
else:
162170
result = kwargs_out
163171

@@ -169,7 +177,7 @@ def call_origin(function, *args, **kwargs):
169177
for res_origin in result:
170178
res = res_origin
171179
if isinstance(res_origin, numpy.ndarray):
172-
res = dpnp_container.empty(res_origin.shape, dtype=res_origin.dtype)
180+
res = dpnp_container.empty(res_origin.shape, dtype=res_origin.dtype, sycl_queue=exec_q)
173181
copy_from_origin(res, res_origin)
174182
result_list.append(res)
175183

0 commit comments

Comments
 (0)