@@ -126,62 +126,6 @@ TEST(Converters, ATenResizeGetItemDynShapeMulCorrectly) {
126
126
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
127
127
}
128
128
129
- TEST (Converters, ATenUnflattenDynShapeShapeCorrectly) {
130
- const auto graph = R"IR(
131
- graph(%x.1 : Tensor):
132
- %2 : int = prim::Constant[value=1]()
133
- %3 : int = prim::Constant[value=512]()
134
- %4 : int = prim::Constant[value=1]()
135
- %5 : int = prim::Constant[value=1]()
136
- %6 : int[] = prim::ListConstruct(%3, %4, %5)
137
- %7 : Tensor = aten::unflatten(%x.1, %2, %6)
138
- return (%7))IR" ;
139
-
140
- auto g = std::make_shared<torch::jit::Graph>();
141
-
142
- torch::jit::parseIR (graph, g.get ());
143
-
144
- auto in = at::randint (0 , 10 , {1 , 512 }, {at::kCUDA });
145
-
146
- auto jit_in = at::clone (in);
147
- auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
148
- auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
149
-
150
- auto trt_in = at::clone (in);
151
- params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
152
- auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic (g, params, {in}, true );
153
-
154
- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
155
- }
156
-
157
- TEST (Converters, ATenUnflattenDynShapeNegativeDimsShapeCorrectly) {
158
- const auto graph = R"IR(
159
- graph(%x.1 : Tensor):
160
- %2 : int = prim::Constant[value=-2]()
161
- %3 : int = prim::Constant[value=512]()
162
- %4 : int = prim::Constant[value=1]()
163
- %5 : int = prim::Constant[value=1]()
164
- %6 : int[] = prim::ListConstruct(%3, %4, %5)
165
- %7 : Tensor = aten::unflatten(%x.1, %2, %6)
166
- return (%7))IR" ;
167
-
168
- auto g = std::make_shared<torch::jit::Graph>();
169
-
170
- torch::jit::parseIR (graph, g.get ());
171
-
172
- auto in = at::randint (0 , 10 , {1 , 512 , 2 }, {at::kCUDA });
173
-
174
- auto jit_in = at::clone (in);
175
- auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
176
- auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
177
-
178
- auto trt_in = at::clone (in);
179
- params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
180
- auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic (g, params, {in}, true );
181
-
182
- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
183
- }
184
-
185
129
TEST (Converters, ATenUnflattenDynShapeITensorShapeCorrectly) {
186
130
const auto graph = R"IR(
187
131
graph(%x.1 : Tensor):
0 commit comments