
Commit fe93c05

vtavana and antonwolfy authored
resolve gh-1871 (#1872)
* update returned result when out is defined with order F
* address comments
* add test for out keyword in einsum

Co-authored-by: Anton <[email protected]>
1 parent 38fd39d commit fe93c05
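
For context on the diff below: per the commit message, the bug class being addressed is an incorrect returned result when an out array is supplied and the data is laid out in "F" order. A minimal illustrative check (not the exact gh-1871 reproducer) that should pass with this change, assuming float32 support on the target device:

import numpy
import dpnp

# F-ordered operands and a preallocated out array
a = dpnp.asarray(numpy.arange(6, dtype=numpy.float32).reshape(2, 3), order="F")
b = dpnp.asarray(numpy.arange(12, dtype=numpy.float32).reshape(3, 4), order="F")
out = dpnp.empty((2, 4), dtype=numpy.float32)

result = dpnp.matmul(a, b, out=out)
expected = numpy.matmul(a.asnumpy(), b.asnumpy())
assert numpy.allclose(result.asnumpy(), expected)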

File tree

5 files changed: +191 −35 lines changed


dpnp/dpnp_iface_linearalgebra.py

Lines changed: 0 additions & 1 deletion
@@ -821,7 +821,6 @@ def matmul(
     """
 
-    dpnp.check_supported_arrays_type(x1, x2)
     if subok is False:
         raise NotImplementedError(
             "subok keyword argument is only supported by its default value."

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 73 additions & 28 deletions
@@ -33,6 +33,7 @@
 import dpctl.tensor._tensor_elementwise_impl as tei
 import dpctl.tensor._tensor_impl as ti
 import numpy
+from dpctl.utils import ExecutionPlacementError
 from numpy.core.numeric import normalize_axis_tuple
 
 import dpnp
@@ -218,7 +219,9 @@ def _compute_size(start, shape):
     return ret
 
 
-def _copy_array(x, dep_events, host_events, copy_flag=False, dtype=None):
+def _copy_array(
+    x, dep_events, host_events, copy_flag=False, dtype=None, order="C"
+):
     """
     Creating a copy of input array if needed.
 
@@ -236,7 +239,7 @@ def _copy_array(x, dep_events, host_events, copy_flag=False, dtype=None):
     copy = x.dtype != dtype if dtype is not None else False
 
     if copy:
-        x_copy = dpnp.empty_like(x, dtype=dtype, order="C")
+        x_copy = dpnp.empty_like(x, dtype=dtype, order=order)
         ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
             src=dpnp.get_usm_ndarray(x),
             dst=x_copy.get_array(),
@@ -248,7 +251,9 @@ def _copy_array(x, dep_events, host_events, copy_flag=False, dtype=None):
     return x
 
 
-def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue):
+def _create_result_array(
+    x1, x2, out, shape, dtype, usm_type, sycl_queue, order="C"
+):
     """
     Create the result array.
 
@@ -263,13 +268,12 @@ def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue):
         x1_usm = dpnp.get_usm_ndarray(x1)
         x2_usm = dpnp.get_usm_ndarray(x2)
         out_usm = dpnp.get_usm_ndarray(out)
-        contig_flag = _define_contig_flag(out)
+        contig_flag, _, _ = _define_contig_flag(out)
 
         if (
             out.dtype == dtype
             and out.shape == shape
             and out.usm_type == usm_type
-            and out.sycl_queue == sycl_queue
             and contig_flag
             and not ti._array_overlap(x1_usm, out_usm)
             and not ti._array_overlap(x2_usm, out_usm)
@@ -279,6 +283,7 @@ def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue):
     return dpnp.empty(
         shape,
         dtype=dtype,
+        order=order,
         usm_type=usm_type,
         sycl_queue=sycl_queue,
     )
@@ -295,14 +300,14 @@ def _define_contig_flag(x):
     x_strides = x.strides
     x_shape = x.shape
     if x.ndim < 2:
-        return True
+        return True, True, True
 
     x_strides = _standardize_strides_to_nonzero(x_strides, x_shape)
     x_is_c_contiguous = x_strides[-1] == 1 and x_strides[-2] == x_shape[-1]
     x_is_f_contiguous = x_strides[-2] == 1 and x_strides[-1] == x_shape[-2]
     if x_is_c_contiguous or x_is_f_contiguous:
         flag = True
-    return flag
+    return flag, x_is_c_contiguous, x_is_f_contiguous
 
 
 def _define_dim_flags(x, pos):
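
For intuition, a standalone NumPy sketch of the check that _define_contig_flag now reports as a triple (overall flag, C-contiguous, F-contiguous) over the last two dimensions. NumPy strides are in bytes, so they are converted to element counts first; the zero-stride normalization done by _standardize_strides_to_nonzero is omitted here.

import numpy

def last_two_dims_contig(x):
    # Mirror the shape of the updated helper's return value:
    # (is_c_or_f, is_c_contiguous, is_f_contiguous) for the trailing two dims.
    if x.ndim < 2:
        return True, True, True
    strides = tuple(s // x.itemsize for s in x.strides)
    is_c = strides[-1] == 1 and strides[-2] == x.shape[-1]
    is_f = strides[-2] == 1 and strides[-1] == x.shape[-2]
    return is_c or is_f, is_c, is_f

print(last_two_dims_contig(numpy.ones((3, 4))))             # (True, True, False)
print(last_two_dims_contig(numpy.ones((3, 4), order="F")))  # (True, False, True)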
@@ -746,17 +751,26 @@ def _gemm_batch_matmul(exec_q, x1, x2, res, dev_tasks_list):
         )
         ht_tasks_list.append(ht_blas_ev)
     dpctl.SyclEvent.wait_for(ht_tasks_list)
+
     res_shape = res.shape
-    if not row_major:
-        res = dpnp.reshape(
-            res.ravel(), (batch_size, res_shape[2], res_shape[1])
-        ).transpose(0, 2, 1)
+    _, res_is_c_contig, res_is_f_contig = _define_contig_flag(res)
+    if row_major:
+        if res_is_f_contig:
+            res = dpnp.reshape(
+                dpnp.ravel(res, order="F"),
+                (res_shape[1], res_shape[2], batch_size),
+            ).transpose(2, 0, 1)
+    else:
+        if res_is_c_contig:
+            res = dpnp.reshape(
+                dpnp.ravel(res, order="C"),
+                (batch_size, res_shape[2], res_shape[1]),
+            ).transpose(0, 2, 1)
 
     if res_shape != orig_shape:
         res = res.reshape(orig_shape)
 
-    res = dpnp.ascontiguousarray(res)
-    return res
+    return dpnp.ascontiguousarray(res)
 
 
 def _gemm_matmul(exec_q, x1, x2, res, dev_tasks_list):
@@ -769,14 +783,16 @@ def _gemm_matmul(exec_q, x1, x2, res, dev_tasks_list):
     )
     ht_blas_ev.wait()
 
-    if not row_major:
-        # TODO: investigate the possibility of defining result
-        # array with "F" order for this case
-        res = dpnp.ascontiguousarray(
-            dpnp.reshape(res.ravel(), res.shape, order="F")
-        )
+    if row_major:
+        if res.flags.f_contiguous is True:
+            # read data in "F" order and write it in "C" order
+            res = dpnp.reshape(dpnp.ravel(res, order="F"), res.shape, order="C")
+    else:
+        if res.flags.c_contiguous is True:
+            # read data in "C" order and write it in "F" order
+            res = dpnp.reshape(dpnp.ravel(res, order="C"), res.shape, order="F")
 
-    return res
+    return dpnp.ascontiguousarray(res)
 
 
 def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
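
The ravel/reshape pattern above rewraps the result's underlying buffer instead of transposing its values. A NumPy-only sketch of the row_major branch (read an F-contiguous array in its memory order, rewrap in "C" order); the oneMKL gemm call that actually fills the buffer is not modeled here.

import numpy

# Raw buffer laid out the way a row-major 2x3 result would be written.
buf = numpy.arange(6)
# An F-ordered view over that buffer shows scrambled logical values.
wrong = buf.reshape((2, 3), order="F")   # [[0 2 4], [1 3 5]]
# Read the buffer in "F" order (its memory order) and rewrap it in "C" order.
fixed = numpy.reshape(numpy.ravel(wrong, order="F"), (2, 3), order="C")
print(fixed)                             # [[0 1 2], [3 4 5]]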
@@ -1746,6 +1762,13 @@ def dpnp_dot(a, b, /, out=None, *, conjugate=False):
         )
 
     res_usm_type, exec_q = get_usm_allocations([a, b])
+    if (
+        out is not None
+        and dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None
+    ):
+        raise ExecutionPlacementError(
+            "Input and output allocation queues are not compatible"
+        )
 
     # Determine the appropriate data types
     dot_dtype, res_dtype = _compute_res_dtype(a, b, sycl_queue=exec_q)
@@ -1812,6 +1835,12 @@ def dpnp_einsum(
         arrays.append(a)
 
     res_usm_type, exec_q = get_usm_allocations(arrays)
+    if out is not None:
+        dpnp.check_supported_arrays_type(out)
+        if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None:
+            raise ExecutionPlacementError(
+                "Input and output allocation queues are not compatible"
+            )
     result_dtype = dpnp.result_type(*arrays) if dtype is None else dtype
     for id, a in enumerate(operands):
         if dpnp.isscalar(a):
@@ -2056,10 +2085,17 @@ def dpnp_matmul(
     """
 
-    x1_ndim = x1.ndim
-    x2_ndim = x2.ndim
+    dpnp.check_supported_arrays_type(x1, x2)
     res_usm_type, exec_q = get_usm_allocations([x1, x2])
+    if out is not None:
+        dpnp.check_supported_arrays_type(out)
+        if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None:
+            raise ExecutionPlacementError(
+                "Input and output allocation queues are not compatible"
+            )
 
+    x1_ndim = x1.ndim
+    x2_ndim = x2.ndim
     if axes is not None:
         axes = _validate_axes(x1, x2, axes)
 
@@ -2072,7 +2108,6 @@ def dpnp_matmul(
         x2 = dpnp.moveaxis(x2, axes_x2, (-2, -1)) if x2_ndim != 1 else x2
         out_orig = out
         if out is not None:
-            dpnp.check_supported_arrays_type(out)
            # out that is passed to the backend should have the correct shape
             if len(axes_res) == 2:
                 out = dpnp.moveaxis(out, axes_res, (-2, -1))
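
The queue checks added to dpnp_dot, dpnp_einsum, and dpnp_matmul enforce compute-follows-data for the out argument. Mirroring the new einsum test at the bottom of this diff, a sketch of the user-visible effect for matmul; the extra dpctl.SyclQueue() exists only to produce an out allocation whose queue cannot be shared with the inputs.

import dpctl
import dpnp
from dpctl.utils import ExecutionPlacementError

a = dpnp.ones((2, 2))
b = dpnp.ones((2, 2))
# out allocated against a separate, freshly created queue
out = dpnp.empty((2, 2), sycl_queue=dpctl.SyclQueue())

try:
    dpnp.matmul(a, b, out=out)
except ExecutionPlacementError as exc:
    print(exc)  # Input and output allocation queues are not compatible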
@@ -2161,8 +2196,18 @@ def dpnp_matmul(
         res = dpnp_dot(x1, x2, out=out)
         res_shape = res.shape
     else:
+        x1_contig_flag, _, x1_f = _define_contig_flag(x1)
+        x2_contig_flag, _, x2_f = _define_contig_flag(x2)
+        res_order = "F" if (x1_f and x2_f and call_flag == "gemm") else "C"
         res = _create_result_array(
-            x1, x2, out, res_shape, compute_dtype, res_usm_type, exec_q
+            x1,
+            x2,
+            out,
+            res_shape,
+            compute_dtype,
+            res_usm_type,
+            exec_q,
+            res_order,
         )
 
         # calculate result
@@ -2175,21 +2220,21 @@ def dpnp_matmul(
             # their base (last 2-dimensions) to be c-contiguous or f-contiguous
             dep_events_list = []
             host_tasks_list = []
-            contig_flag = _define_contig_flag(x1)
             x1 = _copy_array(
                 x1,
                 dep_events_list,
                 host_tasks_list,
-                copy_flag=not contig_flag,
+                copy_flag=not x1_contig_flag,
                 dtype=compute_dtype,
+                order=res_order,
             )
-            contig_flag = _define_contig_flag(x2)
             x2 = _copy_array(
                 x2,
                 dep_events_list,
                 host_tasks_list,
-                copy_flag=not contig_flag,
+                copy_flag=not x2_contig_flag,
                 dtype=compute_dtype,
+                order=res_order,
             )
 
         if call_flag == "gemv":

tests/test_linalg.py

Lines changed: 16 additions & 0 deletions
@@ -613,12 +613,28 @@ def test_einsum_trivial_cases(self):
         expected = numpy.einsum("i,i,i", b_np, b_np, b_np, optimize="greedy")
         assert_dtype_allclose(result, expected)
 
+    def test_einsum_out(self):
+        a = inp.ones((5, 5))
+        a_np = a.asnumpy()
+        out = inp.empty((5,))
+        out_np = out.asnumpy()
+        result = inp.einsum("ii->i", a, out=out)
+        assert result is out
+        expected = numpy.einsum("ii->i", a_np, out=out_np)
+        assert_dtype_allclose(result, expected)
+
     def test_einsum_error(self):
         a = inp.ones((5, 5))
         # unknown keyword argument
         with pytest.raises(TypeError):
             inp.einsum("ii->i", a, copy=False)
 
+        a = inp.ones((5, 5))
+        out = inp.empty((5,), sycl_queue=dpctl.SyclQueue())
+        # inconsistent sycl_queue
+        with pytest.raises(ExecutionPlacementError):
+            inp.einsum("ii->i", a, out=out)
+
         # unknown value for optimize keyword
         with pytest.raises(TypeError):
             inp.einsum("ii->i", a, optimize="average")
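
Assuming the repository's standard pytest-based test layout, the new and extended einsum tests can be run selectively with a command along these lines:

python -m pytest tests/test_linalg.py -k "einsum_out or einsum_error"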
