Skip to content

Commit d01a40b

Browse files
authored
Merge pull request #1700 from mfeliz-cruise/michael.feliz/aten_index_fix
[fix] resolve issue for single non-batch index tensor in aten::index
2 parents e657c1c + 11e4830 commit d01a40b

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

core/conversion/converters/impl/select.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ auto select_registrations TORCHTRT_UNUSED =
360360

361361
// IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices
362362
// from
363-
auto gather_layer = ctx->net->addGather(*in, *indicesTensor, 0);
363+
auto gather_layer = ctx->net->addGather(*in, *indicesTensor, adv_idx_indices[0]);
364364
TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
365365
auto gather_out = gather_layer->getOutput(0);
366366

tests/core/conversion/converters/test_select.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,6 +1147,32 @@ TEST(Converters, ATenIndexTensorIdxsNoneConvertsCorrectly) {
11471147
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
11481148
}
11491149

1150+
TEST(Converters, ATenIndexTensorNoneIdx1ConvertsCorrectly) {
1151+
const auto graph = R"IR(
1152+
graph(%x.1 : Tensor,
1153+
%index0 : Tensor):
1154+
%5 : NoneType = prim::Constant()
1155+
%18 : Tensor?[] = prim::ListConstruct(%5, %index0)
1156+
%19 : Tensor = aten::index(%x.1, %18)
1157+
return (%19))IR";
1158+
1159+
auto g = std::make_shared<torch::jit::Graph>();
1160+
torch::jit::parseIR(graph, g.get());
1161+
1162+
auto in1 = at::randint(1, 10, {1, 3, 480, 928}, {at::kCUDA});
1163+
auto index0 = at::tensor({2, 1, 0}, {at::kCUDA}).to(torch::kLong);
1164+
1165+
auto index0_trt = index0.to(torch::kInt32);
1166+
1167+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
1168+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0});
1169+
1170+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
1171+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt});
1172+
1173+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
1174+
}
1175+
11501176
TEST(Converters, ATenUnbindConvertsCorrectly) {
11511177
const auto graph = R"IR(
11521178
graph(%x.1 : Tensor):

0 commit comments

Comments
 (0)