Skip to content

Commit dacb6b6

Browse files
committed
Test the output shape in test_matmul
1 parent 8d1ded4 commit dacb6b6

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

array_api_tests/test_linalg.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
mutually_promotable_dtypes)
2626
from .pytest_helpers import raises
2727

28+
from .test_broadcasting import broadcast_shapes
29+
2830
from . import _array_module
2931

3032
# Standin strategy for not yet implemented tests
@@ -259,6 +261,18 @@ def test_matmul(x1, x2):
259261
else:
260262
res = _array_module.linalg.matmul(x1, x2)
261263

264+
assert res.dtype == _array_module.result_type(x1, x2), "matmul() did not return the correct dtype"
265+
266+
if len(x1.shape) == len(x2.shape) == 1:
267+
assert res.shape == ()
268+
elif len(x1.shape) == 1:
269+
assert res.shape == x2.shape[:-2] + x2.shape[-1:]
270+
elif len(x2.shape) == 1:
271+
assert res.shape == x1.shape[:-1]
272+
else:
273+
stack_shape = broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
274+
assert res.shape == stack_shape + (x1.shape[-2], x2.shape[-1])
275+
262276
@given(
263277
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
264278
kw=kwargs(axis=todo, keepdims=todo, ord=todo)

0 commit comments

Comments
 (0)