Skip to content

Commit 69b3d79

Browse files
committed
Add support for index_select
1 parent deda87b commit 69b3d79

File tree

2 files changed

+77
-0
lines changed

2 files changed

+77
-0
lines changed

core/conversion/converters/impl/select.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,29 @@ auto select_registrations TORCHTRT_UNUSED =
180180
return true;
181181
}})
182182
.pattern(
183+
{"aten::index_select(Tensor self, int dim, Tensor index) -> Tensor",
184+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
185+
auto in = args[0].ITensorOrFreeze(ctx);
186+
auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
187+
auto dim = args[1].unwrapToInt();
188+
// Handle negative axis by refering to nbDims of input Tensor
189+
dim = dim < 0 ? dim + maxDim : dim;
190+
auto index = args[2].ITensorOrFreeze(ctx);
191+
192+
LOG_DEBUG("Gather input dimensions: " << in->getDimensions());
193+
LOG_DEBUG("Dimension to select: " << dim);
194+
LOG_DEBUG("Index dimensions: " << index->getDimensions());
195+
196+
auto gather_layer = ctx->net->addGather(*in, *index, dim);
197+
TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
198+
auto out = gather_layer->getOutput(0);
199+
LOG_DEBUG("Gather tensor shape: " << out->getDimensions());
200+
201+
out = ctx->AssociateValueAndTensor(n->outputs()[0], out);
202+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
203+
return true;
204+
}})
205+
.pattern(
183206
{"aten::narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)",
184207
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
185208
auto in = args[0].ITensor();

tests/core/conversion/converters/test_select.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,60 @@ TEST(Converters, ATenSelectEmptyTensorConvertsCorrectly) {
165165
ASSERT_TRUE(torch_tensorrt::tests::util::sameShape(jit_results[0], trt_results[0]));
166166
}
167167

168+
TEST(Converters, ATenIndexSelectConvertsCorrectly) {
169+
const auto graph = R"IR(
170+
graph(%0 : Tensor, %index : Int (2)):
171+
%2 : int = prim::Constant[value=0]()
172+
%3 : Tensor = aten::index_select(%0, %2, %index)
173+
return (%3))IR";
174+
auto g = std::make_shared<torch::jit::Graph>();
175+
torch::jit::parseIR(graph, g.get());
176+
auto in = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
177+
auto index = at::randint(0, 4, {2}, {at::kCUDA}).to(torch::kI32);
178+
179+
auto jit_in = at::clone(in);
180+
auto jit_index = at::clone(index);
181+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {jit_index});
182+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
183+
184+
auto trt_in = at::clone(in);
185+
auto trt_index = at::clone(index);
186+
auto trt_params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_index});
187+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, trt_params, {trt_in});
188+
189+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
190+
191+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
192+
}
193+
194+
TEST(Converters, ATenIndexSelectNegativeDimConvertsCorrectly) {
195+
const auto graph = R"IR(
196+
graph(%0 : Tensor, %index : Int (5)):
197+
%2 : int = prim::Constant[value=-1]()
198+
%3 : Tensor = aten::index_select(%0, %2, %index)
199+
return (%3))IR";
200+
auto g = std::make_shared<torch::jit::Graph>();
201+
202+
torch::jit::parseIR(graph, g.get());
203+
204+
auto in = at::randint(1, 10, {5, 3, 9}, {at::kCUDA});
205+
auto index = at::randint(0, 9, {5}, {at::kCUDA}).to(torch::kI32);
206+
207+
auto jit_in = at::clone(in);
208+
auto jit_index = at::clone(index);
209+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {jit_index});
210+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
211+
212+
auto trt_in = at::clone(in);
213+
auto trt_index = at::clone(index);
214+
auto trt_params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_index});
215+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, trt_params, {trt_in});
216+
217+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
218+
219+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
220+
}
221+
168222
TEST(Converters, ATenNarrowStartScalarConvertsCorrectly) {
169223
const auto graph = R"IR(
170224
graph(%x.1 : Tensor):

0 commit comments

Comments
 (0)