Skip to content

Commit dd16dd2

Browse files
committed
Update test to reproduce the exact issue
1 parent 06c66d4 commit dd16dd2

File tree

1 file changed

+6
-14
lines changed

1 file changed

+6
-14
lines changed

tests/test_mathematical.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3885,27 +3885,19 @@ def test_matmul_alias(self):
38853885
@pytest.mark.parametrize(
38863886
"sh1, sh2",
38873887
[
3888-
((2, 3, 3), (3, 3)),
3889-
((3, 4, 4, 4), (4, 4, 4)),
3888+
((2, 3, 3), (2, 3, 3)),
3889+
((3, 3, 3, 3), (3, 3, 3, 3)),
38903890
],
38913891
ids=["gemm", "gemm_batch"],
38923892
)
38933893
def test_matmul_with_offsets(self, sh1, sh2):
38943894
size1, size2 = numpy.prod(sh1, dtype=int), numpy.prod(sh2, dtype=int)
3895-
a = numpy.random.randint(-5, 5, size1).reshape(sh1)
3896-
b = numpy.random.randint(-5, 5, size2).reshape(sh2)
3895+
a = numpy.random.randint(-5, 5, size1).reshape(sh1).astype("f8")
3896+
b = numpy.random.randint(-5, 5, size2).reshape(sh2).astype("f8")
38973897
ia, ib = dpnp.array(a), dpnp.array(b)
38983898

3899-
result = ia[1] @ ib
3900-
expected = a[1] @ b
3901-
assert_array_equal(result, expected)
3902-
3903-
result = ib @ ia[1]
3904-
expected = b @ a[1]
3905-
assert_array_equal(result, expected)
3906-
3907-
result = ia[1] @ ia[1]
3908-
expected = a[1] @ a[1]
3899+
result = ia[1] @ ib[1]
3900+
expected = a[1] @ b[1]
39093901
assert_array_equal(result, expected)
39103902

39113903

0 commit comments

Comments
 (0)