Skip to content

Commit 4f742a9

Browse files
committed
Modifications to matmul and select tests
1 parent 8c8e897 commit 4f742a9

File tree

2 files changed

+20
-14
lines changed

2 files changed

+20
-14
lines changed

py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def test_matmul(self):
1111
class TestModule(torch.nn.Module):
1212
def forward(self, x, y):
1313
return torch.matmul(x, y)
14-
inputOne = torch.randn(1, 32)
14+
inputOne = torch.randn(3, 32)
1515
inputTwo = torch.randn(32, 3)
1616
inputs = [inputOne, inputTwo]
1717
self.run_test(

py/torch_tensorrt/fx/test/converters/aten_op/test_select_aten.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,25 @@
77
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
88

99
class TestSelectConverter(DispatchTestCase):
10-
def test_select(self):
10+
@parameterized.expand(
11+
[
12+
("select_dim_index", 2, 1),
13+
]
14+
)
15+
def test_select(self, dim_test, index_test):
1116
class TestModule(torch.nn.Module):
12-
def forward(self, input, dim, index):
13-
return torch.select(input, dim, index)
17+
def __init__(self, dim, index):
18+
super().__init__()
19+
self.dim = dim
20+
self.index = index
21+
def forward(self, input):
22+
return torch.select(input, self.dim, self.index)
1423
input = [torch.randn(1, 3, 32)]
15-
dim = 2
16-
index = 1
17-
inputs = (input, dim, index)
1824
self.run_test(
19-
TestModule(), input, expected_ops={torch.ops.aten.select.Tensor}, test_explicit_precision=True,
25+
TestModule(dim_test, index_test), input, expected_ops={torch.ops.aten.select}, test_explicit_precision=True,
2026
)
2127

22-
def test_select_with_dynamic_shape(self, x, y):
28+
def test_select_with_dynamic_shape(self, dim_test, index_test):
2329
class TestModule(torch.nn.Module):
2430
def forward(self, input, dim, index):
2531
return torch.select(input, dim, index)
@@ -31,9 +37,9 @@ def forward(self, input, dim, index):
3137
shape_ranges=[((1, 3, 3), (3, 3, 3), (32, 32, 32))],
3238
),
3339
]
34-
dim = 2
35-
index = 1
36-
inputs_spec = (input_spec, dim, index)
3740
self.run_test_with_dynamic_shape(
38-
TestModule(), inputs_spec, expected_ops={torch.ops.aten.select.Tensor}
39-
)
41+
TestModule(dim_test, index_test), input_spec, expected_ops={torch.ops.aten.select}
42+
)
43+
44+
if __name__ == "__main__":
45+
run_tests()

0 commit comments

Comments
 (0)