Skip to content

Commit 06c66d4

Browse files
committed
Add test to explicitly cover the w/a for gemm and gemm_batch
1 parent 93ea22e commit 06c66d4

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

tests/test_mathematical.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3882,6 +3882,32 @@ def test_matmul_alias(self):
38823882
result2 = dpnp.linalg.matmul(a, b)
38833883
assert_array_equal(result1, result2)
38843884

3885+
@pytest.mark.parametrize(
3886+
"sh1, sh2",
3887+
[
3888+
((2, 3, 3), (3, 3)),
3889+
((3, 4, 4, 4), (4, 4, 4)),
3890+
],
3891+
ids=["gemm", "gemm_batch"],
3892+
)
3893+
def test_matmul_with_offsets(self, sh1, sh2):
3894+
size1, size2 = numpy.prod(sh1, dtype=int), numpy.prod(sh2, dtype=int)
3895+
a = numpy.random.randint(-5, 5, size1).reshape(sh1)
3896+
b = numpy.random.randint(-5, 5, size2).reshape(sh2)
3897+
ia, ib = dpnp.array(a), dpnp.array(b)
3898+
3899+
result = ia[1] @ ib
3900+
expected = a[1] @ b
3901+
assert_array_equal(result, expected)
3902+
3903+
result = ib @ ia[1]
3904+
expected = b @ a[1]
3905+
assert_array_equal(result, expected)
3906+
3907+
result = ia[1] @ ia[1]
3908+
expected = a[1] @ a[1]
3909+
assert_array_equal(result, expected)
3910+
38853911

38863912
class TestMatmulInvalidCases:
38873913
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)