@@ -1093,6 +1093,32 @@ TEST(Converters, ATenIndexTensorIdxsNoneConvertsCorrectly) {
1093
1093
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
1094
1094
}
1095
1095
1096
+ TEST (Converters, ATenIndexTensorNoneIdx1ConvertsCorrectly) {
1097
+ const auto graph = R"IR(
1098
+ graph(%x.1 : Tensor,
1099
+ %index0 : Tensor):
1100
+ %5 : NoneType = prim::Constant()
1101
+ %18 : Tensor?[] = prim::ListConstruct(%5, %index0)
1102
+ %19 : Tensor = aten::index(%x.1, %18)
1103
+ return (%19))IR" ;
1104
+
1105
+ auto g = std::make_shared<torch::jit::Graph>();
1106
+ torch::jit::parseIR (graph, g.get ());
1107
+
1108
+ auto in1 = at::randint (1 , 10 , {1 , 3 , 480 , 928 }, {at::kCUDA });
1109
+ auto index0 = at::tensor ({2 , 1 , 0 }, {at::kCUDA }).to (torch::kLong );
1110
+
1111
+ auto index0_trt = index0.to (torch::kInt32 );
1112
+
1113
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
1114
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in1, index0});
1115
+
1116
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
1117
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in1, index0_trt});
1118
+
1119
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
1120
+ }
1121
+
1096
1122
TEST (Converters, ATenUnbindConvertsCorrectly) {
1097
1123
const auto graph = R"IR(
1098
1124
graph(%x.1 : Tensor):
0 commit comments