Skip to content

Commit 6090cf1

Browse files
committed
Correct sparse arg in aten::gather
1 parent c48a306 commit 6090cf1

File tree

3 files changed

+12
-9
lines changed

3 files changed

+12
-9
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def aten_ops_gather(
204204
input=args[0],
205205
dim=args[1],
206206
index=args[2],
207-
sparse_grad=args_bounds_check(args, 4, False),
207+
sparse_grad=args_bounds_check(args, 3, False),
208208
)
209209

210210

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def gather(
8282
input: TRTTensor,
8383
dim: int,
8484
index: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
85+
sparse_grad: bool = False,
8586
) -> TRTTensor:
8687
gather_layer = ctx.net.add_gather(input, index, dim)
8788
set_layer_name(gather_layer, target, name + "_gather", source_ir)

tests/py/dynamo/conversion/test_gather_aten.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,27 @@
22

33
import torch
44
import torch.nn as nn
5-
from harness import DispatchTestCase
5+
from .harness import DispatchTestCase
66
from torch.testing._internal.common_utils import run_tests
77
from torch_tensorrt import Input
88

99

10-
class TestIndexConverter(DispatchTestCase):
11-
def test_index_zero_two_dim(self):
10+
class TestGatherConverter(DispatchTestCase):
11+
def test_gather_zero_two_dim(self):
1212
class TestModule(nn.Module):
1313
def __init__(self):
14-
self.index0 = torch.randint(0, 1, (1, 1))
14+
# self.index0 = torch.randint(0, 1, (1, 1))
1515
super().__init__()
1616

17-
def forward(self, x):
18-
index0 = torch.randint(0, 1, (1, 1))
19-
indices = [None, self.index0]
17+
def forward(self, x, indices):
18+
# index0 = torch.randint(0, 1, (1, 1))
19+
# indices = [None, self.index0]
2020
out = torch.ops.aten.gather.default(x, 0, indices)
2121
return out
2222

23-
input = [torch.randn(2, 2)]
23+
index0 = torch.randint(0, 1, (1, 1), dtype=torch.int32)
24+
indices = [None, index0]
25+
input = [torch.randn(2, 2), index0]
2426
self.run_test(
2527
TestModule(),
2628
input,

0 commit comments

Comments
 (0)