Skip to content

Commit 28c0242

Browse files
committed
Start to implement test_matmul()
For now it only tests cases where it should or should not raise an exception.
1 parent 6c51306 commit 28c0242

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

array_api_tests/test_linalg.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
positive_definite_matrices, MAX_ARRAY_SIZE,
2424
invertible_matrices,
2525
mutually_promotable_dtypes)
26+
from .pytest_helpers import raises
2627

2728
from . import _array_module
2829

@@ -246,8 +247,18 @@ def test_inv(x):
246247
x2=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
247248
)
248249
def test_matmul(x1, x2):
249-
# res = _array_module.linalg.matmul(x1, x2)
250-
pass
250+
if (x1.shape == () or x2.shape == ()
251+
or len(x1.shape) == len(x2.shape) == 1 and x1.shape != x2.shape
252+
or len(x1.shape) == 1 and len(x2.shape) >= 2 and x1.shape[0] != x2.shape[-2]
253+
or len(x2.shape) == 1 and len(x1.shape) >= 2 and x2.shape[0] != x1.shape[-1]
254+
or len(x1.shape) >= 2 and len(x2.shape) >= 2 and x1.shape[-1] != x2.shape[-2]):
255+
# The spec doesn't specify what kind of exception is used here. Most
256+
# libraries will use a custom exception class.
257+
raises(Exception, lambda: _array_module.linalg.matmul(x1, x2),
258+
"matmul did not raise an exception for invalid shapes")
259+
return
260+
else:
261+
res = _array_module.linalg.matmul(x1, x2)
251262

252263
@given(
253264
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),

0 commit comments

Comments
 (0)