Skip to content

Commit da62d5f

Browse files
Arm backend: Clean up matmul tests (#10971)
- Removes duplicated matmul tests. - Replaces pytest.mark_flaky with qtol for quantized tests cases of mm/bmm. Signed-off-by: Oscar Andersson <[email protected]>
1 parent 08dfe52 commit da62d5f

File tree

2 files changed

+4
-26
lines changed

2 files changed

+4
-26
lines changed

backends/arm/test/ops/test_bmm.py

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,6 @@ def forward(self, x, y):
4444
return torch.bmm(x, y)
4545

4646

47-
class MatMul(torch.nn.Module):
48-
test_data_generators = {
49-
"rand_3d": lambda: (torch.rand(2, 3, 5), torch.rand(2, 5, 2)),
50-
"rand_4d": lambda: (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)),
51-
}
52-
53-
def forward(self, x, y):
54-
return torch.matmul(x, y)
55-
56-
5747
class BMMSingleInput(torch.nn.Module):
5848
test_data_generators = {
5949
"rand_3d_1": lambda: (torch.rand(20, 3, 3),),
@@ -81,26 +71,14 @@ def test_bmm_tosa_MI_single_input(test_data: input_t1):
8171
pipeline.run()
8272

8373

84-
@common.parametrize("test_data", MatMul.test_data_generators)
85-
def test_mm_tosa_MI(test_data: input_t1):
86-
pipeline = TosaPipelineMI[input_t1](MatMul(), test_data(), aten_op_mm, exir_op_mm)
87-
pipeline.run()
88-
89-
90-
@common.parametrize("test_data", MatMul.test_data_generators)
91-
def test_mm_tosa_BI(test_data: input_t1):
92-
pipeline = TosaPipelineBI[input_t1](MatMul(), test_data(), aten_op_mm, exir_op_mm)
93-
pipeline.run()
94-
95-
96-
@pytest.mark.flaky(reruns=5) # TODO: Investigate flakyness (MLETORCH-534)
9774
@common.parametrize("test_data", BMM.test_data_generators)
9875
def test_bmm_tosa_BI(test_data: input_t1):
99-
pipeline = TosaPipelineBI[input_t1](BMM(), test_data(), aten_op_bmm, exir_op_bmm)
76+
pipeline = TosaPipelineBI[input_t1](
77+
BMM(), test_data(), aten_op_bmm, exir_op_bmm, qtol=1
78+
)
10079
pipeline.run()
10180

10281

103-
@pytest.mark.flaky(reruns=5) # TODO: Investigate flakyness (MLETORCH-534)
10482
@common.parametrize("test_data", BMMSingleInput.test_data_generators)
10583
def test_bmm_tosa_BI_single_input(test_data: input_t1):
10684
pipeline = TosaPipelineBI[input_t1](

backends/arm/test/ops/test_mm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_mm_tosa_MI(test_data: Tuple):
4141

4242
@common.parametrize("test_data", MM.test_data_generators)
4343
def test_mm_tosa_BI(test_data: Tuple):
44-
TosaPipelineBI[test_t](MM(), test_data(), MM.aten_op, MM.exir_op).run()
44+
TosaPipelineBI[test_t](MM(), test_data(), MM.aten_op, MM.exir_op, qtol=1).run()
4545

4646

4747
@common.parametrize("test_data", MM.test_data_generators)

0 commit comments

Comments
 (0)