Skip to content

Commit 30aa861

Browse files
committed
Align expection text for inplace matrix multiplication
1 parent e682912 commit 30aa861

File tree

3 files changed

+24
-10
lines changed

3 files changed

+24
-10
lines changed

dpnp/dpnp_array.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
# *****************************************************************************
2626

2727
import dpctl.tensor as dpt
28+
from dpctl.tensor._numpy_helper import AxisError
2829

2930
import dpnp
3031

@@ -379,7 +380,14 @@ def __imatmul__(self, other):
379380
else:
380381
axes = [(-2, -1), (-2, -1), (-2, -1)]
381382

382-
dpnp.matmul(self, other, out=self, axes=axes)
383+
try:
384+
dpnp.matmul(self, other, out=self, axes=axes)
385+
except AxisError:
386+
# AxisError should indicate that the axes argument didn't work out
387+
# which should mean the second operand not being 2 dimensional.
388+
raise ValueError(
389+
"inplace matrix multiplication requires the first operand to have at least one and the second at least two dimensions."
390+
)
383391
return self
384392

385393
def __imod__(self, other):

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import dpctl.tensor._tensor_impl as ti
2929
import dpctl.utils as dpu
3030
import numpy
31-
from dpctl.tensor._numpy_helper import normalize_axis_tuple
31+
from dpctl.tensor._numpy_helper import AxisError, normalize_axis_tuple
3232
from dpctl.utils import ExecutionPlacementError
3333

3434
import dpnp
@@ -525,15 +525,15 @@ def _validate_internal(axes, i, ndim):
525525
)
526526

527527
if len(axes) != 1:
528-
raise ValueError(
528+
raise AxisError(
529529
f"Axes item {i} should be a tuple with a single element, or an integer."
530530
)
531531
else:
532532
iter = 2
533533
if not isinstance(axes, tuple):
534534
raise TypeError(f"Axes item {i} should be a tuple.")
535535
if len(axes) != 2:
536-
raise ValueError(
536+
raise AxisError(
537537
f"Axes item {i} should be a tuple with 2 elements."
538538
)
539539

@@ -563,7 +563,7 @@ def _validate_internal(axes, i, ndim):
563563

564564
if x1_ndim == 1 and x2_ndim == 1:
565565
if axes[2] != ():
566-
raise ValueError("Axes item 2 should be an empty tuple.")
566+
raise AxisError("Axes item 2 should be an empty tuple.")
567567
elif x1_ndim == 1 or x2_ndim == 1:
568568
axes[2] = _validate_internal(axes[2], 2, 1)
569569
else:

tests/test_mathematical.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4216,10 +4216,16 @@ def test_shapes(self, a_sh, b_sh):
42164216

42174217
expected = a @ b
42184218
if expected.shape != a_sh:
4219-
with pytest.raises(ValueError):
4219+
if len(b_sh) == 1:
4220+
# check the exception matches NumPy
4221+
match = "inplace matrix multiplication requires"
4222+
else:
4223+
match = None
4224+
4225+
with pytest.raises(ValueError, match=match):
42204226
a @= b
42214227

4222-
with pytest.raises(ValueError):
4228+
with pytest.raises(ValueError, match=match):
42234229
ia @= ib
42244230
else:
42254231
ia @= ib
@@ -4356,7 +4362,7 @@ def test_matmul_axes(self, xp):
43564362

43574363
# axes item should be a tuple with 2 elements
43584364
axes = [(3, 1), (2, 0), (0, 1, 2)]
4359-
with pytest.raises(ValueError):
4365+
with pytest.raises(AxisError):
43604366
xp.matmul(a1, a2, axes=axes)
43614367

43624368
# axes must be an integer
@@ -4367,7 +4373,7 @@ def test_matmul_axes(self, xp):
43674373
# axes item 2 should be an empty tuple
43684374
a = xp.arange(3)
43694375
axes = [0, 0, 0]
4370-
with pytest.raises(ValueError):
4376+
with pytest.raises(AxisError):
43714377
xp.matmul(a, a, axes=axes)
43724378

43734379
a = xp.arange(3 * 4 * 5).reshape(3, 4, 5)
@@ -4379,7 +4385,7 @@ def test_matmul_axes(self, xp):
43794385

43804386
# axes item should be a tuple with a single element, or an integer
43814387
axes = [(1, 0), (0), (0, 1)]
4382-
with pytest.raises(ValueError):
4388+
with pytest.raises(AxisError):
43834389
xp.matmul(a, b, axes=axes)
43844390

43854391

0 commit comments

Comments
 (0)