Skip to content

Commit 5b6f9a9

Browse files
committed
Linting fix and adding test case
1 parent e593c9b commit 5b6f9a9

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,9 @@ def index(
253253
dim_tensor_list[adv_indx_indices[i]],
254254
)
255255

256-
gather_out = gather(ctx, target, source_ir, name, flatten_tensor, 0, cum_adv_index)
256+
gather_out = gather(
257+
ctx, target, source_ir, name, flatten_tensor, 0, cum_adv_index
258+
)
257259
_LOGGER.debug(f"The shape after cumultative gather is {gather_out.shape}")
258260
_LOGGER.debug(f"The shape for cumulative adv index is {cum_adv_index}")
259261

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import operator
2+
3+
import torch
4+
import torch.nn as nn
5+
from harness import DispatchTestCase
6+
from torch.testing._internal.common_utils import run_tests
7+
from torch_tensorrt import Input
8+
9+
10+
class TestIndexConverter(DispatchTestCase):
11+
def test_index_zero_two_dim(self):
12+
class TestModule(nn.Module):
13+
def __init__(self):
14+
self.index0 = torch.randint(0, 1, (1, 1))
15+
super().__init__()
16+
17+
def forward(self, x):
18+
index0 = torch.randint(0, 1, (1, 1))
19+
indices = [None, self.index0]
20+
out = torch.ops.aten.gather.default(x, 0, indices)
21+
return out
22+
23+
input = [torch.randn(2, 2)]
24+
self.run_test(
25+
TestModule(),
26+
input,
27+
)
28+
29+
if __name__ == "__main__":
30+
run_tests()

0 commit comments

Comments
 (0)