@@ -165,6 +165,60 @@ TEST(Converters, ATenSelectEmptyTensorConvertsCorrectly) {
165
165
ASSERT_TRUE (torch_tensorrt::tests::util::sameShape (jit_results[0 ], trt_results[0 ]));
166
166
}
167
167
168
+ TEST (Converters, ATenIndexSelectConvertsCorrectly) {
169
+ const auto graph = R"IR(
170
+ graph(%0 : Tensor, %index : Int (2)):
171
+ %2 : int = prim::Constant[value=0]()
172
+ %3 : Tensor = aten::index_select(%0, %2, %index)
173
+ return (%3))IR" ;
174
+ auto g = std::make_shared<torch::jit::Graph>();
175
+ torch::jit::parseIR (graph, g.get ());
176
+ auto in = at::randint (1 , 10 , {4 , 4 , 4 }, {at::kCUDA });
177
+ auto index = at::randint (0 , 4 , {2 }, {at::kCUDA }).to (torch::kI32 );
178
+
179
+ auto jit_in = at::clone (in);
180
+ auto jit_index = at::clone (index);
181
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {jit_index});
182
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
183
+
184
+ auto trt_in = at::clone (in);
185
+ auto trt_index = at::clone (index);
186
+ auto trt_params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {trt_index});
187
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, trt_params, {trt_in});
188
+
189
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
190
+
191
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
192
+ }
193
+
194
+ TEST (Converters, ATenIndexSelectNegativeDimConvertsCorrectly) {
195
+ const auto graph = R"IR(
196
+ graph(%0 : Tensor, %index : Int (5)):
197
+ %2 : int = prim::Constant[value=-1]()
198
+ %3 : Tensor = aten::index_select(%0, %2, %index)
199
+ return (%3))IR" ;
200
+ auto g = std::make_shared<torch::jit::Graph>();
201
+
202
+ torch::jit::parseIR (graph, g.get ());
203
+
204
+ auto in = at::randint (1 , 10 , {5 , 3 , 9 }, {at::kCUDA });
205
+ auto index = at::randint (0 , 9 , {5 }, {at::kCUDA }).to (torch::kI32 );
206
+
207
+ auto jit_in = at::clone (in);
208
+ auto jit_index = at::clone (index);
209
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {jit_index});
210
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
211
+
212
+ auto trt_in = at::clone (in);
213
+ auto trt_index = at::clone (index);
214
+ auto trt_params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {trt_index});
215
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, trt_params, {trt_in});
216
+
217
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
218
+
219
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
220
+ }
221
+
168
222
TEST (Converters, ATenNarrowStartScalarConvertsCorrectly) {
169
223
const auto graph = R"IR(
170
224
graph(%x.1 : Tensor):
0 commit comments