Skip to content

Commit 81f40a3

Browse files
committed
Index ITensor test
1 parent 563ca81 commit 81f40a3

File tree

1 file changed

+28
-2
lines changed

1 file changed

+28
-2
lines changed

tests/py/dynamo/conversion/test_index_aten.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@
22

33
import torch
44
import torch.nn as nn
5+
from harness import DispatchTestCase
56
from torch.testing._internal.common_utils import run_tests
67
from torch_tensorrt import Input
78

8-
from .harness import DispatchTestCase
9-
109

1110
class TestIndexConverter(DispatchTestCase):
1211
def test_index_zero_two_dim(self):
@@ -27,6 +26,21 @@ def forward(self, x):
2726
input,
2827
)
2928

29+
def test_index_zero_two_dim_ITensor(self):
30+
class TestModule(nn.Module):
31+
def forward(self, x, index0):
32+
indices = [None, index0]
33+
out = torch.ops.aten.index.Tensor(x, indices)
34+
return out
35+
36+
input = torch.randn(2, 2)
37+
index0 = torch.randint(0, 1, (1, 1))
38+
index0 = index0.to(torch.int32)
39+
self.run_test(
40+
TestModule(),
41+
[input, index0],
42+
)
43+
3044
def test_index_zero_index_three_dim(self):
3145
class TestModule(nn.Module):
3246
def __init__(self):
@@ -44,6 +58,18 @@ def forward(self, x):
4458
input,
4559
)
4660

61+
def test_index_zero_index_three_dim_ITensor(self):
62+
class TestModule(nn.Module):
63+
def forward(self, x, index0):
64+
indices = [None, index0, None]
65+
out = torch.ops.aten.index.Tensor(x, indices)
66+
return out
67+
68+
input = torch.randn(2, 2, 2)
69+
index0 = torch.randint(0, 1, (1, 1))
70+
index0 = index0.to(torch.int32)
71+
self.run_test(TestModule(), [input, index0])
72+
4773
def test_index_zero_index_one_index_two_three_dim(self):
4874
class TestModule(nn.Module):
4975
def __init__(self):

0 commit comments

Comments
 (0)