@@ -1147,6 +1147,32 @@ TEST(Converters, ATenIndexTensorIdxsNoneConvertsCorrectly) {
1147
1147
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
1148
1148
}
1149
1149
1150
+ TEST (Converters, ATenIndexTensorNoneIdx1ConvertsCorrectly) {
1151
+ const auto graph = R"IR(
1152
+ graph(%x.1 : Tensor,
1153
+ %index0 : Tensor):
1154
+ %5 : NoneType = prim::Constant()
1155
+ %18 : Tensor?[] = prim::ListConstruct(%5, %index0)
1156
+ %19 : Tensor = aten::index(%x.1, %18)
1157
+ return (%19))IR" ;
1158
+
1159
+ auto g = std::make_shared<torch::jit::Graph>();
1160
+ torch::jit::parseIR (graph, g.get ());
1161
+
1162
+ auto in1 = at::randint (1 , 10 , {1 , 3 , 480 , 928 }, {at::kCUDA });
1163
+ auto index0 = at::tensor ({2 , 1 , 0 }, {at::kCUDA }).to (torch::kLong );
1164
+
1165
+ auto index0_trt = index0.to (torch::kInt32 );
1166
+
1167
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
1168
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in1, index0});
1169
+
1170
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
1171
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in1, index0_trt});
1172
+
1173
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
1174
+ }
1175
+
1150
1176
TEST (Converters, ATenUnbindConvertsCorrectly) {
1151
1177
const auto graph = R"IR(
1152
1178
graph(%x.1 : Tensor):
0 commit comments