Skip to content

Commit a0ff737

Browse files
committed
Correcting index test cas
1 parent 4ad2791 commit a0ff737

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
1010
from torch_tensorrt.dynamo.conversion.converter_utils import (
1111
broadcastable,
12+
cast_trt_tensor,
1213
get_positive_dim,
1314
get_trt_tensor,
1415
to_numpy,
@@ -81,16 +82,16 @@ def gather(
8182
name: str,
8283
input: TRTTensor,
8384
dim: int,
84-
index: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
85+
index: Union[TRTTensor, np.ndarray, torch.Tensor],
8586
sparse_grad: bool = False,
8687
) -> TRTTensor:
87-
indices_tensor = []
88-
89-
for i, ind in enumerate(index):
90-
indices_tensor.append(
91-
get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
92-
)
93-
gather_layer = ctx.net.add_gather(input, indices_tensor, dim)
88+
if not isinstance(index, TRTTensor):
89+
index = get_trt_tensor(ctx, index, name + f"_parameter_to_fp32_tensor")
90+
# This is for the case where torch.ops.aten.gather requires torch.int64
91+
# However TRTInterpreter complains that torch.int64 is not a supported type
92+
# So the below cast does not help
93+
# index = cast_trt_tensor(ctx, input, trt.int32, name, target, source_ir)
94+
gather_layer = ctx.net.add_gather(input, index, dim)
9495
set_layer_name(gather_layer, target, name + "_gather", source_ir)
9596
return gather_layer.get_output(0)
9697

tests/py/dynamo/conversion/test_gather_aten.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,25 @@
22

33
import torch
44
import torch.nn as nn
5-
from .harness import DispatchTestCase
65
from torch.testing._internal.common_utils import run_tests
76
from torch_tensorrt import Input
87

8+
from .harness import DispatchTestCase
9+
910

1011
class TestGatherConverter(DispatchTestCase):
1112
def test_gather_zero_two_dim(self):
1213
class TestModule(nn.Module):
1314
def __init__(self):
14-
# self.index0 = torch.randint(0, 1, (1, 1))
1515
super().__init__()
1616

1717
def forward(self, x, indices):
1818
# index0 = torch.randint(0, 1, (1, 1))
19-
# indices = [None, self.index0]
2019
out = torch.ops.aten.gather.default(x, 0, indices)
2120
return out
2221

23-
index0 = torch.randint(0, 1, (1, 1), dtype=torch.int32)
24-
indices = [None, index0]
22+
# index0 = torch.randint(0, 1, (1, 1), dtype=torch.int32)
23+
index0 = torch.randint(0, 1, (1, 1))
2524
input = [torch.randn(2, 2), index0]
2625
self.run_test(
2726
TestModule(),

0 commit comments

Comments
 (0)