@@ -44,16 +44,6 @@ def forward(self, x, y):
44
44
return torch .bmm (x , y )
45
45
46
46
47
- class MatMul (torch .nn .Module ):
48
- test_data_generators = {
49
- "rand_3d" : lambda : (torch .rand (2 , 3 , 5 ), torch .rand (2 , 5 , 2 )),
50
- "rand_4d" : lambda : (torch .rand (1 , 2 , 3 , 5 ), torch .rand (1 , 2 , 5 , 2 )),
51
- }
52
-
53
- def forward (self , x , y ):
54
- return torch .matmul (x , y )
55
-
56
-
57
47
class BMMSingleInput (torch .nn .Module ):
58
48
test_data_generators = {
59
49
"rand_3d_1" : lambda : (torch .rand (20 , 3 , 3 ),),
@@ -81,26 +71,14 @@ def test_bmm_tosa_MI_single_input(test_data: input_t1):
81
71
pipeline .run ()
82
72
83
73
84
- @common .parametrize ("test_data" , MatMul .test_data_generators )
85
- def test_mm_tosa_MI (test_data : input_t1 ):
86
- pipeline = TosaPipelineMI [input_t1 ](MatMul (), test_data (), aten_op_mm , exir_op_mm )
87
- pipeline .run ()
88
-
89
-
90
- @common .parametrize ("test_data" , MatMul .test_data_generators )
91
- def test_mm_tosa_BI (test_data : input_t1 ):
92
- pipeline = TosaPipelineBI [input_t1 ](MatMul (), test_data (), aten_op_mm , exir_op_mm )
93
- pipeline .run ()
94
-
95
-
96
- @pytest .mark .flaky (reruns = 5 ) # TODO: Investigate flakyness (MLETORCH-534)
97
74
@common .parametrize ("test_data" , BMM .test_data_generators )
98
75
def test_bmm_tosa_BI (test_data : input_t1 ):
99
- pipeline = TosaPipelineBI [input_t1 ](BMM (), test_data (), aten_op_bmm , exir_op_bmm )
76
+ pipeline = TosaPipelineBI [input_t1 ](
77
+ BMM (), test_data (), aten_op_bmm , exir_op_bmm , qtol = 1
78
+ )
100
79
pipeline .run ()
101
80
102
81
103
- @pytest .mark .flaky (reruns = 5 ) # TODO: Investigate flakyness (MLETORCH-534)
104
82
@common .parametrize ("test_data" , BMMSingleInput .test_data_generators )
105
83
def test_bmm_tosa_BI_single_input (test_data : input_t1 ):
106
84
pipeline = TosaPipelineBI [input_t1 ](
0 commit comments