Skip to content

Commit 0393595

Browse files
Arm backend: Remove duplicted Matmul-tests
torch.matmul was tested in both test_bmm and test_matmul. Signed-off-by: Oscar Andersson <[email protected]> Change-Id: Ia84384071e6d27a7c8b171cfc79aba0b4b1e2422
1 parent 6ad47df commit 0393595

File tree

1 file changed

+0
-22
lines changed

1 file changed

+0
-22
lines changed

backends/arm/test/ops/test_bmm.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,6 @@ def forward(self, x, y):
4444
return torch.bmm(x, y)
4545

4646

47-
class MatMul(torch.nn.Module):
48-
test_data_generators = {
49-
"rand_3d": lambda: (torch.rand(2, 3, 5), torch.rand(2, 5, 2)),
50-
"rand_4d": lambda: (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)),
51-
}
52-
53-
def forward(self, x, y):
54-
return torch.matmul(x, y)
55-
56-
5747
class BMMSingleInput(torch.nn.Module):
5848
test_data_generators = {
5949
"rand_3d_1": lambda: (torch.rand(20, 3, 3),),
@@ -81,18 +71,6 @@ def test_bmm_tosa_MI_single_input(test_data: input_t1):
8171
pipeline.run()
8272

8373

84-
@common.parametrize("test_data", MatMul.test_data_generators)
85-
def test_mm_tosa_MI(test_data: input_t1):
86-
pipeline = TosaPipelineMI[input_t1](MatMul(), test_data(), aten_op_mm, exir_op_mm)
87-
pipeline.run()
88-
89-
90-
@common.parametrize("test_data", MatMul.test_data_generators)
91-
def test_mm_tosa_BI(test_data: input_t1):
92-
pipeline = TosaPipelineBI[input_t1](MatMul(), test_data(), aten_op_mm, exir_op_mm)
93-
pipeline.run()
94-
95-
9674
@pytest.mark.flaky(reruns=5) # TODO: Investigate flakyness (MLETORCH-534)
9775
@common.parametrize("test_data", BMM.test_data_generators)
9876
def test_bmm_tosa_BI(test_data: input_t1):

0 commit comments

Comments
 (0)