Skip to content

Commit d45bb24

Browse files
authored
Improve performance of dpnp.matmul and dpnp.dot with out keyword (#1694)
* use out keyword for result * fix strided or overlapping out * address comments * fix typo * remove additional check
1 parent 1a3866e commit d45bb24

File tree

2 files changed

+52
-22
lines changed

2 files changed

+52
-22
lines changed

dpnp/dpnp_iface.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -495,17 +495,20 @@ def get_result_array(a, out=None, casting="safe"):
495495
if out is None:
496496
return a
497497
else:
498-
dpnp.check_supported_arrays_type(out)
499-
if out.shape != a.shape:
500-
raise ValueError(
501-
f"Output array of shape {a.shape} is needed, got {out.shape}."
502-
)
503-
elif isinstance(out, dpt.usm_ndarray):
504-
out = dpnp_array._create_from_usm_ndarray(out)
498+
if a is out:
499+
return out
500+
else:
501+
dpnp.check_supported_arrays_type(out)
502+
if out.shape != a.shape:
503+
raise ValueError(
504+
f"Output array of shape {a.shape} is needed, got {out.shape}."
505+
)
506+
elif isinstance(out, dpt.usm_ndarray):
507+
out = dpnp_array._create_from_usm_ndarray(out)
505508

506-
dpnp.copyto(out, a, casting=casting)
509+
dpnp.copyto(out, a, casting=casting)
507510

508-
return out
511+
return out
509512

510513

511514
def get_usm_ndarray(a):

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,41 @@
3636
__all__ = ["dpnp_dot", "dpnp_matmul"]
3737

3838

39+
def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue):
40+
"""
41+
Create the result array.
42+
43+
If `out` is not ``None`` and its features match the specified `shape`, `dtype,
44+
`usm_type`, and `sycl_queue` and it is C-contiguous or F-contiguous and
45+
does not have any memory overlap with `x1` and `x2`, `out` itself is returned.
46+
If these conditions are not statisfied, an empty array is returned with the
47+
specified `shape`, `dtype, `usm_type`, and `sycl_queue`.
48+
"""
49+
50+
if out is not None:
51+
x1_usm = dpnp.get_usm_ndarray(x1)
52+
x2_usm = dpnp.get_usm_ndarray(x2)
53+
out_usm = dpnp.get_usm_ndarray(out)
54+
55+
if (
56+
out.dtype == dtype
57+
and out.shape == shape
58+
and out.usm_type == usm_type
59+
and out.sycl_queue == sycl_queue
60+
and (out.flags.c_contiguous or out.flags.f_contiguous)
61+
and not ti._array_overlap(x1_usm, out_usm)
62+
and not ti._array_overlap(x2_usm, out_usm)
63+
):
64+
return out
65+
66+
return dpnp.empty(
67+
shape,
68+
dtype=dtype,
69+
usm_type=usm_type,
70+
sycl_queue=sycl_queue,
71+
)
72+
73+
3974
def _copy_array(x, dep_events, host_events, contig_copy=False, dtype=None):
4075
"""
4176
Creating a copy of input array if needed.
@@ -214,14 +249,9 @@ def dpnp_dot(a, b, /, out=None, *, conjugate=False):
214249
a, b, dtype=None, casting="no", sycl_queue=exec_q
215250
)
216251

217-
# create result array
218-
result = dpnp.empty(
219-
(),
220-
dtype=dot_dtype,
221-
usm_type=res_usm_type,
222-
sycl_queue=exec_q,
252+
result = _create_result_array(
253+
a, b, out, (), dot_dtype, res_usm_type, exec_q
223254
)
224-
225255
# input arrays should have the proper data type
226256
dep_events_list = []
227257
host_tasks_list = []
@@ -367,13 +397,10 @@ def dpnp_matmul(
367397
x2_shape = x2.shape
368398
res_shape = tuple(tmp_shape) + (x1_shape[-2], x2_shape[-1])
369399

370-
# calculate results
371-
result = dpnp.empty(
372-
res_shape,
373-
dtype=gemm_dtype,
374-
usm_type=res_usm_type,
375-
sycl_queue=exec_q,
400+
result = _create_result_array(
401+
x1, x2, out, res_shape, gemm_dtype, res_usm_type, exec_q
376402
)
403+
# calculate result
377404
if result.size == 0:
378405
pass
379406
elif x1.size == 0 or x2.size == 0:

0 commit comments

Comments
 (0)