Skip to content

Commit 5501ace

Browse files
committed
small fix: Index validator enable int64
- Repair test case
1 parent e38a7f3 commit 5501ace

File tree

2 files changed

+3
-8
lines changed

2 files changed

+3
-8
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def index_dtype_validator(node: Node) -> bool:
397397
for ind in index:
398398
if ind is not None:
399399
val = ind.meta.get("val")
400-
if val is not None and val.dtype != torch.int32:
400+
if val is not None and val.dtype not in (torch.int32, torch.int64):
401401
return False
402402
return True
403403

tests/py/dynamo/conversion/test_index_aten.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
import operator
2-
31
import torch
42
import torch.nn as nn
5-
from .harness import DispatchTestCase
63
from torch.testing._internal.common_utils import run_tests
7-
from torch_tensorrt import Input
4+
5+
from .harness import DispatchTestCase
86

97

108
class TestIndexConverter(DispatchTestCase):
@@ -15,7 +13,6 @@ def __init__(self):
1513
super().__init__()
1614

1715
def forward(self, x):
18-
index0 = torch.randint(0, 1, (1, 1))
1916
indices = [None, self.index0]
2017
out = torch.ops.aten.index.Tensor(x, indices)
2118
return out
@@ -158,8 +155,6 @@ def __init__(self):
158155
super().__init__()
159156

160157
def forward(self, x):
161-
index0 = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7])
162-
index1 = index0.unsqueeze(0).T.long()
163158
indices = [None, None, self.index0, self.index1]
164159
out = torch.ops.aten.index.Tensor(x, indices)
165160
return out

0 commit comments

Comments
 (0)