2
2
3
3
import torch
4
4
import torch .nn as nn
5
+ from harness import DispatchTestCase
5
6
from torch .testing ._internal .common_utils import run_tests
6
7
from torch_tensorrt import Input
7
8
8
- from .harness import DispatchTestCase
9
-
10
9
11
10
class TestIndexConverter (DispatchTestCase ):
12
11
def test_index_zero_two_dim (self ):
@@ -27,6 +26,21 @@ def forward(self, x):
27
26
input ,
28
27
)
29
28
29
+ def test_index_zero_two_dim_ITensor (self ):
30
+ class TestModule (nn .Module ):
31
+ def forward (self , x , index0 ):
32
+ indices = [None , index0 ]
33
+ out = torch .ops .aten .index .Tensor (x , indices )
34
+ return out
35
+
36
+ input = torch .randn (2 , 2 )
37
+ index0 = torch .randint (0 , 1 , (1 , 1 ))
38
+ index0 = index0 .to (torch .int32 )
39
+ self .run_test (
40
+ TestModule (),
41
+ [input , index0 ],
42
+ )
43
+
30
44
def test_index_zero_index_three_dim (self ):
31
45
class TestModule (nn .Module ):
32
46
def __init__ (self ):
@@ -44,6 +58,18 @@ def forward(self, x):
44
58
input ,
45
59
)
46
60
61
+ def test_index_zero_index_three_dim_ITensor (self ):
62
+ class TestModule (nn .Module ):
63
+ def forward (self , x , index0 ):
64
+ indices = [None , index0 , None ]
65
+ out = torch .ops .aten .index .Tensor (x , indices )
66
+ return out
67
+
68
+ input = torch .randn (2 , 2 , 2 )
69
+ index0 = torch .randint (0 , 1 , (1 , 1 ))
70
+ index0 = index0 .to (torch .int32 )
71
+ self .run_test (TestModule (), [input , index0 ])
72
+
47
73
def test_index_zero_index_one_index_two_three_dim (self ):
48
74
class TestModule (nn .Module ):
49
75
def __init__ (self ):
0 commit comments