Skip to content

Commit d600363

Browse files
committed
address more comments
1 parent 295e0a7 commit d600363

File tree

3 files changed

+129
-16
lines changed

3 files changed

+129
-16
lines changed

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue):
5858
and out.shape == shape
5959
and out.usm_type == usm_type
6060
and out.sycl_queue == sycl_queue
61-
and (out.flags.c_contiguous or out.flags.f_contiguous)
61+
and out.flags.c_contiguous
6262
and not ti._array_overlap(x1_usm, out_usm)
6363
and not ti._array_overlap(x2_usm, out_usm)
6464
):
@@ -417,7 +417,10 @@ def dpnp_matmul(
417417
if out is not None:
418418
dpnp.check_supported_arrays_type(out)
419419
# out that is passed to the backend should have the correct shape
420-
out = dpnp.moveaxis(out, axes_res, (-2, -1))
420+
if len(axes_res) == 2:
421+
out = dpnp.moveaxis(out, axes_res, (-2, -1))
422+
elif len(axes_res) == 1:
423+
out = dpnp.moveaxis(out, axes_res, (-1,))
421424

422425
appended_axes = []
423426
if x1_ndim == 1:
@@ -439,6 +442,37 @@ def dpnp_matmul(
439442
f"(size {x1_shape[-1]} is different from {x2_shape[-2]})"
440443
)
441444

445+
if out is not None:
446+
out_shape = out.shape
447+
if not appended_axes:
448+
if out_shape[-2] != x1_shape[-2]:
449+
raise ValueError(
450+
"Output array has a mismatch in its core dimension 0. "
451+
"The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) "
452+
f"(size {out_shape[-2]} is different from {x1_shape[-2]})"
453+
)
454+
if out_shape[-1] != x2_shape[-1]:
455+
raise ValueError(
456+
"Output array has a mismatch in its core dimension 1. "
457+
"The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) "
458+
f"(size {out_shape[-1]} is different from {x2_shape[-1]})"
459+
)
460+
elif len(appended_axes) == 1:
461+
if appended_axes[0] == -1:
462+
if out_shape[-1] != x1_shape[-2]:
463+
raise ValueError(
464+
"Output array has a mismatch in its core dimension 0. "
465+
"The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) "
466+
f"(size {out_shape[-1]} is different from {x1_shape[-2]})"
467+
)
468+
elif appended_axes[0] == -2:
469+
if out_shape[-1] != x2_shape[-1]:
470+
raise ValueError(
471+
"Output array has a mismatch in its core dimension 0. "
472+
"The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) "
473+
f"(size {out_shape[-1]} is different from {x2_shape[-1]})"
474+
)
475+
442476
# Determine the appropriate data types
443477
gemm_dtype, res_dtype = _op_res_dtype(
444478
x1, x2, dtype=dtype, casting=casting, sycl_queue=exec_q
@@ -483,10 +517,25 @@ def dpnp_matmul(
483517
x2 = dpnp.repeat(x2, x1_shape[i], axis=i)
484518
else:
485519
raise ValueError(
486-
"arrays could not be broadcast together with remapped shapes."
520+
"Input arrays could not be broadcast together with remapped shapes, "
521+
f"{x1_shape[:-2]} is different from {x2_shape[:-2]}."
487522
)
523+
488524
x1_shape = x1.shape
489525
x2_shape = x2.shape
526+
if out is not None:
527+
for i in range(x1_ndim - 2):
528+
if x1_shape[i] != out_shape[i]:
529+
if not appended_axes:
530+
raise ValueError(
531+
"Output array could not be broadcast together with remapped shapes, "
532+
f"{x1_shape[:-2]} is different from {out_shape[:-2]}."
533+
)
534+
elif len(appended_axes) == 1:
535+
raise ValueError(
536+
"Output array could not be broadcast together with remapped shapes, "
537+
f"{x1_shape[:-2]} is different from {out_shape[:-1]}."
538+
)
490539
res_shape = tuple(tmp_shape) + (x1_shape[-2], x2_shape[-1])
491540

492541
# handling a special case to provide a similar result to NumPy
@@ -559,6 +608,8 @@ def dpnp_matmul(
559608

560609
if appended_axes:
561610
result = dpnp.squeeze(result, tuple(appended_axes))
611+
if len(appended_axes) == 2 and out is not None:
612+
result = dpnp.tile(result, out.shape)
562613

563614
if x1_is_2D and x2_is_2D:
564615
# add new axes only if one of the input arrays
@@ -572,7 +623,7 @@ def dpnp_matmul(
572623

573624
if out is None:
574625
if axes is not None:
575-
# Move the result to the appropriate axes of out array
626+
# Move the data to the appropriate axes of the result array
576627
if len(axes_res) == 2:
577628
result = dpnp.moveaxis(result, (-2, -1), axes_res)
578629
elif len(axes_res) == 1:
@@ -586,8 +637,7 @@ def dpnp_matmul(
586637
return result
587638
else:
588639
result = dpnp.get_result_array(result, out, casting=casting)
589-
if axes is not None:
590-
if out is result:
591-
# out and out_orig contain the same data but they have different shape
592-
return out_orig
640+
if axes is not None and out is result:
641+
# out and out_orig contain the same data but they have different shape
642+
return out_orig
593643
return result

tests/test_mathematical.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2729,6 +2729,29 @@ def test_matmul_axes_out(self, dtype, axes, out_shape):
27292729
# TODO: investigate the effect of factor, see SAT-6700
27302730
assert_dtype_allclose(result, expected, factor=24)
27312731

2732+
@pytest.mark.parametrize(
2733+
"axes, b_shape, out_shape",
2734+
[
2735+
([(1, 0), 0, 0], (3,), (4, 5)),
2736+
([(1, 0), 0, 1], (3,), (5, 4)),
2737+
([(1, 0), (0, 1), (1, 2)], (3, 1), (5, 4, 1)),
2738+
([(1, 0), (0, 1), (0, 2)], (3, 1), (4, 5, 1)),
2739+
([(1, 0), (0, 1), (1, 0)], (3, 1), (1, 4, 5)),
2740+
],
2741+
)
2742+
def test_matmul_axes_out_1D(self, axes, b_shape, out_shape):
2743+
a = numpy.arange(3 * 4 * 5).reshape(3, 4, 5)
2744+
b = numpy.arange(3).reshape(b_shape)
2745+
ia = dpnp.array(a)
2746+
ib = dpnp.array(b)
2747+
2748+
out_dp = dpnp.empty(out_shape)
2749+
out_np = numpy.empty(out_shape)
2750+
result = dpnp.matmul(ia, ib, axes=axes, out=out_dp)
2751+
assert result is out_dp
2752+
expected = numpy.matmul(a, b, axes=axes, out=out_np)
2753+
assert_dtype_allclose(result, expected)
2754+
27322755
@pytest.mark.parametrize("dtype1", get_all_dtypes(no_bool=True))
27332756
@pytest.mark.parametrize(
27342757
"dtype2", get_all_dtypes(no_bool=True, no_none=True)
@@ -2855,6 +2878,25 @@ def test_matmul_out(self, dtype):
28552878
assert result is dpnp_out
28562879
assert_dtype_allclose(result, expected)
28572880

2881+
@pytest.mark.parametrize(
2882+
"out_shape",
2883+
[
2884+
((4, 5)),
2885+
((6,)),
2886+
((4, 7, 2)),
2887+
],
2888+
)
2889+
def test_matmul_out_0D(self, out_shape):
2890+
a = numpy.arange(3)
2891+
b = dpnp.asarray(a)
2892+
2893+
numpy_out = numpy.empty(out_shape)
2894+
dpnp_out = dpnp.empty(out_shape)
2895+
result = dpnp.matmul(b, b, out=dpnp_out)
2896+
expected = numpy.matmul(a, a, out=numpy_out)
2897+
assert result is dpnp_out
2898+
assert_dtype_allclose(result, expected)
2899+
28582900

28592901
class TestMatmulInvalidCases:
28602902
@pytest.mark.parametrize(
@@ -2892,6 +2934,27 @@ def test_invalid_shape(self, shape_pair):
28922934
with pytest.raises(ValueError):
28932935
xp.matmul(x1, x2)
28942936

2937+
@pytest.mark.parametrize(
2938+
"shape_pair",
2939+
[
2940+
((5, 4, 3), (3, 1), (3, 4, 1)),
2941+
((5, 4, 3), (3, 1), (5, 6, 1)),
2942+
((5, 4, 3), (3, 1), (5, 4, 2)),
2943+
((5, 4, 3), (3,), (5, 3)),
2944+
((5, 4, 3), (3,), (6, 4)),
2945+
((3,), (3, 4, 5), (3, 5)),
2946+
((3,), (3, 4, 5), (4, 6)),
2947+
],
2948+
)
2949+
def test_invalid_shape_out(self, shape_pair):
2950+
for xp in (numpy, dpnp):
2951+
shape1, shape2, out_shape = shape_pair
2952+
x1 = xp.arange(numpy.prod(shape1), dtype=xp.float32).reshape(shape1)
2953+
x2 = xp.arange(numpy.prod(shape2), dtype=xp.float32).reshape(shape2)
2954+
res = xp.empty(out_shape)
2955+
with pytest.raises(ValueError):
2956+
xp.matmul(x1, x2, out=res)
2957+
28952958
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)[:-2])
28962959
def test_invalid_dtype(self, dtype):
28972960
dpnp_dtype = get_all_dtypes(no_none=True)[-1]

tests/third_party/cupy/math_tests/test_matmul.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,11 @@ class TestMatmulLarge(unittest.TestCase):
182182
rtol=1e-3, atol=1e-3, type_check=has_support_aspect64()
183183
) # required for uint8
184184
def test_operator_matmul(self, xp, dtype1, dtype2):
185-
if (dtype1, dtype1) in self.skip_dtypes or (
186-
dtype1,
185+
if (dtype1, dtype2) in self.skip_dtypes or (
186+
dtype2,
187187
dtype1,
188188
) in self.skip_dtypes:
189-
return xp.array([])
189+
pytest.skip()
190190
x1 = testing.shaped_random(self.shape_pair[0], xp, dtype1)
191191
x2 = testing.shaped_random(self.shape_pair[1], xp, dtype2)
192192
return operator.matmul(x1, x2)
@@ -197,11 +197,11 @@ def test_operator_matmul(self, xp, dtype1, dtype2):
197197
rtol=1e-3, atol=1e-3, type_check=has_support_aspect64()
198198
) # required for uint8
199199
def test_cupy_matmul(self, xp, dtype1, dtype2):
200-
if (dtype1, dtype1) in self.skip_dtypes or (
201-
dtype1,
200+
if (dtype1, dtype2) in self.skip_dtypes or (
201+
dtype2,
202202
dtype1,
203203
) in self.skip_dtypes:
204-
return xp.array([])
204+
pytest.skip()
205205
shape1, shape2 = self.shape_pair
206206
x1 = testing.shaped_random(shape1, xp, dtype1)
207207
x2 = testing.shaped_random(shape2, xp, dtype2)
@@ -211,7 +211,7 @@ def test_cupy_matmul(self, xp, dtype1, dtype2):
211211
@pytest.mark.parametrize(
212212
"shape1, shape2",
213213
[
214-
# TODO: include it when issue #1540 in dpctl is resolved
214+
# the first one causes overflow which is undefined behavior
215215
# ((256, 256, 3, 2), (256, 256, 2, 4)),
216216
((256, 256, 3, 2), (2, 4)),
217217
((3, 2), (256, 256, 2, 4)),
@@ -233,7 +233,7 @@ def test_cupy_matmul(self, xp, dtype, shape1, shape2):
233233
return xp.matmul(x1, x2)
234234

235235

236-
@pytest.mark.skip("until issue #1540 in dpctl is resolved")
236+
@pytest.mark.skip("overflow is undefined behavior.")
237237
class TestMatmulOverflow(unittest.TestCase):
238238
@testing.for_int_dtypes(name="dtype", no_bool=True)
239239
@testing.numpy_cupy_allclose(rtol=1e-3, atol=1e-3) # required for uint8

0 commit comments

Comments
 (0)