@@ -3885,27 +3885,19 @@ def test_matmul_alias(self):
3885
3885
@pytest .mark .parametrize (
3886
3886
"sh1, sh2" ,
3887
3887
[
3888
- ((2 , 3 , 3 ), (3 , 3 )),
3889
- ((3 , 4 , 4 , 4 ), (4 , 4 , 4 )),
3888
+ ((2 , 3 , 3 ), (2 , 3 , 3 )),
3889
+ ((3 , 3 , 3 , 3 ), (3 , 3 , 3 , 3 )),
3890
3890
],
3891
3891
ids = ["gemm" , "gemm_batch" ],
3892
3892
)
3893
3893
def test_matmul_with_offsets (self , sh1 , sh2 ):
3894
3894
size1 , size2 = numpy .prod (sh1 , dtype = int ), numpy .prod (sh2 , dtype = int )
3895
- a = numpy .random .randint (- 5 , 5 , size1 ).reshape (sh1 )
3896
- b = numpy .random .randint (- 5 , 5 , size2 ).reshape (sh2 )
3895
+ a = numpy .random .randint (- 5 , 5 , size1 ).reshape (sh1 ). astype ( "f8" )
3896
+ b = numpy .random .randint (- 5 , 5 , size2 ).reshape (sh2 ). astype ( "f8" )
3897
3897
ia , ib = dpnp .array (a ), dpnp .array (b )
3898
3898
3899
- result = ia [1 ] @ ib
3900
- expected = a [1 ] @ b
3901
- assert_array_equal (result , expected )
3902
-
3903
- result = ib @ ia [1 ]
3904
- expected = b @ a [1 ]
3905
- assert_array_equal (result , expected )
3906
-
3907
- result = ia [1 ] @ ia [1 ]
3908
- expected = a [1 ] @ a [1 ]
3899
+ result = ia [1 ] @ ib [1 ]
3900
+ expected = a [1 ] @ b [1 ]
3909
3901
assert_array_equal (result , expected )
3910
3902
3911
3903
0 commit comments