Skip to content

Commit 11e4830

Browse files
committed
aten::index fix
1 parent d35fe2a commit 11e4830

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

core/conversion/converters/impl/select.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ auto select_registrations TORCHTRT_UNUSED =
337337

338338
// IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices
339339
// from
340-
auto gather_layer = ctx->net->addGather(*in, *indicesTensor, 0);
340+
auto gather_layer = ctx->net->addGather(*in, *indicesTensor, adv_idx_indices[0]);
341341
TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
342342
auto gather_out = gather_layer->getOutput(0);
343343

tests/core/conversion/converters/test_select.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,6 +1093,32 @@ TEST(Converters, ATenIndexTensorIdxsNoneConvertsCorrectly) {
10931093
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
10941094
}
10951095

1096+
TEST(Converters, ATenIndexTensorNoneIdx1ConvertsCorrectly) {
1097+
const auto graph = R"IR(
1098+
graph(%x.1 : Tensor,
1099+
%index0 : Tensor):
1100+
%5 : NoneType = prim::Constant()
1101+
%18 : Tensor?[] = prim::ListConstruct(%5, %index0)
1102+
%19 : Tensor = aten::index(%x.1, %18)
1103+
return (%19))IR";
1104+
1105+
auto g = std::make_shared<torch::jit::Graph>();
1106+
torch::jit::parseIR(graph, g.get());
1107+
1108+
auto in1 = at::randint(1, 10, {1, 3, 480, 928}, {at::kCUDA});
1109+
auto index0 = at::tensor({2, 1, 0}, {at::kCUDA}).to(torch::kLong);
1110+
1111+
auto index0_trt = index0.to(torch::kInt32);
1112+
1113+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
1114+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0});
1115+
1116+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
1117+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt});
1118+
1119+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
1120+
}
1121+
10961122
TEST(Converters, ATenUnbindConvertsCorrectly) {
10971123
const auto graph = R"IR(
10981124
graph(%x.1 : Tensor):

0 commit comments

Comments
 (0)