Skip to content

Commit fe4c862

Browse files
committed
Add test to explicitly cover the w/a for gemm and gemm_batch
1 parent 11d3a94 commit fe4c862

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

tests/test_mathematical.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3824,6 +3824,32 @@ def test_matmul_alias(self):
38243824
result2 = dpnp.linalg.matmul(a, b)
38253825
assert_array_equal(result1, result2)
38263826

3827+
@pytest.mark.parametrize(
3828+
"sh1, sh2",
3829+
[
3830+
((2, 3, 3), (3, 3)),
3831+
((3, 4, 4, 4), (4, 4, 4)),
3832+
],
3833+
ids=["gemm", "gemm_batch"],
3834+
)
3835+
def test_matmul_with_offsets(self, sh1, sh2):
3836+
size1, size2 = numpy.prod(sh1, dtype=int), numpy.prod(sh2, dtype=int)
3837+
a = numpy.random.randint(-5, 5, size1).reshape(sh1)
3838+
b = numpy.random.randint(-5, 5, size2).reshape(sh2)
3839+
ia, ib = dpnp.array(a), dpnp.array(b)
3840+
3841+
result = ia[1] @ ib
3842+
expected = a[1] @ b
3843+
assert_array_equal(result, expected)
3844+
3845+
result = ib @ ia[1]
3846+
expected = b @ a[1]
3847+
assert_array_equal(result, expected)
3848+
3849+
result = ia[1] @ ia[1]
3850+
expected = a[1] @ a[1]
3851+
assert_array_equal(result, expected)
3852+
38273853

38283854
class TestMatmulInvalidCases:
38293855
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)