File tree Expand file tree Collapse file tree 1 file changed +14
-0
lines changed Expand file tree Collapse file tree 1 file changed +14
-0
lines changed Original file line number Diff line number Diff line change 25
25
mutually_promotable_dtypes )
26
26
from .pytest_helpers import raises
27
27
28
+ from .test_broadcasting import broadcast_shapes
29
+
28
30
from . import _array_module
29
31
30
32
# Standin strategy for not yet implemented tests
@@ -259,6 +261,18 @@ def test_matmul(x1, x2):
259
261
else :
260
262
res = _array_module .linalg .matmul (x1 , x2 )
261
263
264
+ assert res .dtype == _array_module .result_type (x1 , x2 ), "matmul() did not return the correct dtype"
265
+
266
+ if len (x1 .shape ) == len (x2 .shape ) == 1 :
267
+ assert res .shape == ()
268
+ elif len (x1 .shape ) == 1 :
269
+ assert res .shape == x2 .shape [:- 2 ] + x2 .shape [- 1 :]
270
+ elif len (x2 .shape ) == 1 :
271
+ assert res .shape == x1 .shape [:- 1 ]
272
+ else :
273
+ stack_shape = broadcast_shapes (x1 .shape [:- 2 ], x2 .shape [:- 2 ])
274
+ assert res .shape == stack_shape + (x1 .shape [- 2 ], x2 .shape [- 1 ])
275
+
262
276
@given (
263
277
x = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ),
264
278
kw = kwargs (axis = todo , keepdims = todo , ord = todo )
You can’t perform that action at this time.
0 commit comments