|
36 | 36 | __all__ = ["dpnp_dot", "dpnp_matmul"]
|
37 | 37 |
|
38 | 38 |
|
| 39 | +def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue): |
| 40 | + """ |
| 41 | + Create the result array. |
| 42 | +
|
| 43 | + If `out` is not ``None`` and its features match the specified `shape`, `dtype, |
| 44 | + `usm_type`, and `sycl_queue` and it is C-contiguous or F-contiguous and |
| 45 | + does not have any memory overlap with `x1` and `x2`, `out` itself is returned. |
| 46 | + If these conditions are not statisfied, an empty array is returned with the |
| 47 | + specified `shape`, `dtype, `usm_type`, and `sycl_queue`. |
| 48 | + """ |
| 49 | + |
| 50 | + if out is not None: |
| 51 | + x1_usm = dpnp.get_usm_ndarray(x1) |
| 52 | + x2_usm = dpnp.get_usm_ndarray(x2) |
| 53 | + out_usm = dpnp.get_usm_ndarray(out) |
| 54 | + |
| 55 | + if ( |
| 56 | + out.dtype == dtype |
| 57 | + and out.shape == shape |
| 58 | + and out.usm_type == usm_type |
| 59 | + and out.sycl_queue == sycl_queue |
| 60 | + and (out.flags.c_contiguous or out.flags.f_contiguous) |
| 61 | + and not ti._array_overlap(x1_usm, out_usm) |
| 62 | + and not ti._array_overlap(x2_usm, out_usm) |
| 63 | + ): |
| 64 | + return out |
| 65 | + |
| 66 | + return dpnp.empty( |
| 67 | + shape, |
| 68 | + dtype=dtype, |
| 69 | + usm_type=usm_type, |
| 70 | + sycl_queue=sycl_queue, |
| 71 | + ) |
| 72 | + |
| 73 | + |
39 | 74 | def _copy_array(x, dep_events, host_events, contig_copy=False, dtype=None):
|
40 | 75 | """
|
41 | 76 | Creating a copy of input array if needed.
|
@@ -214,14 +249,9 @@ def dpnp_dot(a, b, /, out=None, *, conjugate=False):
|
214 | 249 | a, b, dtype=None, casting="no", sycl_queue=exec_q
|
215 | 250 | )
|
216 | 251 |
|
217 |
| - # create result array |
218 |
| - result = dpnp.empty( |
219 |
| - (), |
220 |
| - dtype=dot_dtype, |
221 |
| - usm_type=res_usm_type, |
222 |
| - sycl_queue=exec_q, |
| 252 | + result = _create_result_array( |
| 253 | + a, b, out, (), dot_dtype, res_usm_type, exec_q |
223 | 254 | )
|
224 |
| - |
225 | 255 | # input arrays should have the proper data type
|
226 | 256 | dep_events_list = []
|
227 | 257 | host_tasks_list = []
|
@@ -367,13 +397,10 @@ def dpnp_matmul(
|
367 | 397 | x2_shape = x2.shape
|
368 | 398 | res_shape = tuple(tmp_shape) + (x1_shape[-2], x2_shape[-1])
|
369 | 399 |
|
370 |
| - # calculate results |
371 |
| - result = dpnp.empty( |
372 |
| - res_shape, |
373 |
| - dtype=gemm_dtype, |
374 |
| - usm_type=res_usm_type, |
375 |
| - sycl_queue=exec_q, |
| 400 | + result = _create_result_array( |
| 401 | + x1, x2, out, res_shape, gemm_dtype, res_usm_type, exec_q |
376 | 402 | )
|
| 403 | + # calculate result |
377 | 404 | if result.size == 0:
|
378 | 405 | pass
|
379 | 406 | elif x1.size == 0 or x2.size == 0:
|
|
0 commit comments