@@ -124,4 +124,137 @@ TEST(Converters, ATenResizeGetItemDynShapeMulCorrectly) {
124
124
auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
125
125
126
126
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
127
- }
127
+ }
128
+
129
+ TEST (Converters, ATenUnflattenDynShapeShapeCorrectly) {
130
+ const auto graph = R"IR(
131
+ graph(%x.1 : Tensor):
132
+ %2 : int = prim::Constant[value=1]()
133
+ %3 : int = prim::Constant[value=512]()
134
+ %4 : int = prim::Constant[value=1]()
135
+ %5 : int = prim::Constant[value=1]()
136
+ %6 : int[] = prim::ListConstruct(%3, %4, %5)
137
+ %7 : Tensor = aten::unflatten(%x.1, %2, %6)
138
+ return (%7))IR" ;
139
+
140
+ auto g = std::make_shared<torch::jit::Graph>();
141
+
142
+ torch::jit::parseIR (graph, g.get ());
143
+
144
+ auto in = at::randint (0 , 10 , {1 , 512 }, {at::kCUDA });
145
+
146
+ auto jit_in = at::clone (in);
147
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
148
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
149
+
150
+ auto trt_in = at::clone (in);
151
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
152
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic (g, params, {in}, true );
153
+
154
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
155
+ }
156
+
157
+ TEST (Converters, ATenUnflattenDynShapeNegativeDimsShapeCorrectly) {
158
+ const auto graph = R"IR(
159
+ graph(%x.1 : Tensor):
160
+ %2 : int = prim::Constant[value=-2]()
161
+ %3 : int = prim::Constant[value=512]()
162
+ %4 : int = prim::Constant[value=1]()
163
+ %5 : int = prim::Constant[value=1]()
164
+ %6 : int[] = prim::ListConstruct(%3, %4, %5)
165
+ %7 : Tensor = aten::unflatten(%x.1, %2, %6)
166
+ return (%7))IR" ;
167
+
168
+ auto g = std::make_shared<torch::jit::Graph>();
169
+
170
+ torch::jit::parseIR (graph, g.get ());
171
+
172
+ auto in = at::randint (0 , 10 , {1 , 512 , 2 }, {at::kCUDA });
173
+
174
+ auto jit_in = at::clone (in);
175
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
176
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
177
+
178
+ auto trt_in = at::clone (in);
179
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
180
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic (g, params, {in}, true );
181
+
182
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
183
+ }
184
+
185
+ TEST (Converters, ATenUnflattenDynShapeITensorShapeCorrectly) {
186
+ const auto graph = R"IR(
187
+ graph(%x.1 : Tensor):
188
+ %2 : int = prim::Constant[value=1]()
189
+ %3 : int = aten::size(%x.1, %2)
190
+ %4 : int = prim::Constant[value=256]()
191
+ %5 : int = prim::Constant[value=2]()
192
+ %6 : int[] = prim::ListConstruct(%4, %5)
193
+ %7 : Tensor = aten::unflatten(%x.1, %2, %6)
194
+ return (%7))IR" ;
195
+ auto g = std::make_shared<torch::jit::Graph>();
196
+ torch::jit::parseIR (graph, g.get ());
197
+
198
+ auto in = at::randint (0 , 10 , {1 , 512 , 1 }, {at::kCUDA });
199
+
200
+ auto jit_in = at::clone (in);
201
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
202
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
203
+
204
+ auto trt_in = at::clone (in);
205
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
206
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic (g, params, {in}, true );
207
+
208
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
209
+ }
210
+
211
+ TEST (Converters, ATenUnflattenDynShapeITensorShapeCorrectlyFirstDim) {
212
+ const auto graph = R"IR(
213
+ graph(%x.1 : Tensor):
214
+ %1 : int = prim::Constant[value=0]()
215
+ %2 : int = prim::Constant[value=1]()
216
+ %3 : int = aten::size(%x.1, %1)
217
+ %6 : int[] = prim::ListConstruct(%2, %2, %3, %2, %2)
218
+ %7 : Tensor = aten::unflatten(%x.1, %1, %6)
219
+ return (%7))IR" ;
220
+ auto g = std::make_shared<torch::jit::Graph>();
221
+ torch::jit::parseIR (graph, g.get ());
222
+
223
+ auto in = at::randint (0 , 10 , {64 , 512 , 1 }, {at::kCUDA });
224
+
225
+ auto jit_in = at::clone (in);
226
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
227
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
228
+
229
+ auto trt_in = at::clone (in);
230
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
231
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic (g, params, {in}, true );
232
+
233
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
234
+ }
235
+
236
+ TEST (Converters, ATenUnflattenDynShapeITensorShapeCorrectlyLastDim) {
237
+ const auto graph = R"IR(
238
+ graph(%x.1 : Tensor):
239
+ %1 : int = prim::Constant[value=2]()
240
+ %2 : int = prim::Constant[value=1]()
241
+ %3 : int = aten::size(%x.1, %1)
242
+ %5 : int = prim::Constant[value=2]()
243
+ %6 : int[] = prim::ListConstruct(%3, %2, %2)
244
+ %7 : Tensor = aten::unflatten(%x.1, %5, %6)
245
+ return (%7))IR" ;
246
+ auto g = std::make_shared<torch::jit::Graph>();
247
+ torch::jit::parseIR (graph, g.get ());
248
+
249
+ auto in = at::randint (0 , 10 , {1 , 512 , 9 }, {at::kCUDA });
250
+
251
+ auto jit_in = at::clone (in);
252
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
253
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
254
+
255
+ auto trt_in = at::clone (in);
256
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
257
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic (g, params, {in}, true );
258
+
259
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
260
+ }
0 commit comments