7
7
from torch_tensorrt .fx .tools .common_fx2trt import DispatchTestCase , InputTensorSpec
8
8
9
9
class TestSelectConverter (DispatchTestCase ):
10
- def test_select (self ):
10
+ @parameterized .expand (
11
+ [
12
+ ("select_dim_index" , 2 , 1 ),
13
+ ]
14
+ )
15
+ def test_select (self , dim_test , index_test ):
11
16
class TestModule (torch .nn .Module ):
12
- def forward (self , input , dim , index ):
13
- return torch .select (input , dim , index )
17
+ def __init__ (self , dim , index ):
18
+ super ().__init__ ()
19
+ self .dim = dim
20
+ self .index = index
21
+ def forward (self , input ):
22
+ return torch .select (input , self .dim , self .index )
14
23
input = [torch .randn (1 , 3 , 32 )]
15
- dim = 2
16
- index = 1
17
- inputs = (input , dim , index )
18
24
self .run_test (
19
- TestModule (), input , expected_ops = {torch .ops .aten .select . Tensor }, test_explicit_precision = True ,
25
+ TestModule (dim_test , index_test ), input , expected_ops = {torch .ops .aten .select }, test_explicit_precision = True ,
20
26
)
21
27
22
- def test_select_with_dynamic_shape (self , x , y ):
28
+ def test_select_with_dynamic_shape (self , dim_test , index_test ):
23
29
class TestModule (torch .nn .Module ):
24
30
def forward (self , input , dim , index ):
25
31
return torch .select (input , dim , index )
@@ -31,9 +37,9 @@ def forward(self, input, dim, index):
31
37
shape_ranges = [((1 , 3 , 3 ), (3 , 3 , 3 ), (32 , 32 , 32 ))],
32
38
),
33
39
]
34
- dim = 2
35
- index = 1
36
- inputs_spec = (input_spec , dim , index )
37
40
self .run_test_with_dynamic_shape (
38
- TestModule (), inputs_spec , expected_ops = {torch .ops .aten .select .Tensor }
39
- )
41
+ TestModule (dim_test , index_test ), input_spec , expected_ops = {torch .ops .aten .select }
42
+ )
43
+
44
+ if __name__ == "__main__" :
45
+ run_tests ()
0 commit comments