
Commit 4b2e2f9

Author: Anurag Dixit
chore: rebase with main branch
Signed-off-by: Anurag Dixit <[email protected]>
1 parent b5dbc11 commit 4b2e2f9

File tree

3 files changed: +284 -1 lines changed

core/conversion/converters/impl/shuffle.cpp

Lines changed: 98 additions & 0 deletions
@@ -64,6 +64,104 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
               LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
               return true;
             }})
+        .pattern(
+            {"aten::unflatten.int(Tensor self, int dim, int[] sizes) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto in = args[0].ITensorOrFreeze(ctx);
+               auto dim = args[1].unwrapToInt();
+               auto in_shape = util::toVec(in->getDimensions());
+               std::vector<int64_t> new_shape;
+               nvinfer1::ITensor* shape_tensor;
+               if (ctx->input_is_dynamic) {
+                 /*
+                  * Normalize a negative dim. If a negative dim is out of range
+                  * for in_shape, an index-out-of-bounds error is expected.
+                  */
+                 if (dim < 0) {
+                   dim = in_shape.size() + dim;
+                 }
+                 LOG_DEBUG("Using dynamic version of reshape layer");
+                 if (args[2].isITensorList()) {
+                   LOG_DEBUG("Shape tensor is an ITensorList");
+                   auto expand_shape = args[2].unwrapToITensorList();
+                   auto shape_layer = ctx->net->addShape(*in);
+                   TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
+                   auto shape_1d_tensor = shape_layer->getOutput(0);
+
+                   // Gather the input dimensions preceding dim.
+                   std::vector<int> before_dim_indices_vector(dim);
+                   std::iota(before_dim_indices_vector.begin(), before_dim_indices_vector.end(), 0);
+
+                   nvinfer1::ITensor* before_dim_gather_out = nullptr;
+                   if (before_dim_indices_vector.size()) {
+                     at::Tensor before_dim_indices = torch::tensor(before_dim_indices_vector).to(torch::kI32);
+                     auto before_dim_indices_out = converters::tensor_to_const(ctx, before_dim_indices);
+                     auto before_dim_gather_layer = ctx->net->addGather(*shape_1d_tensor, *before_dim_indices_out, 0);
+                     TORCHTRT_CHECK(before_dim_gather_layer, "Unable to create gather layer from node: " << *n);
+                     before_dim_gather_out = before_dim_gather_layer->getOutput(0);
+                   }
+
+                   // Gather the input dimensions following dim.
+                   std::vector<int> after_dim_indices_vector(in_shape.size() - (dim + 1));
+                   std::iota(after_dim_indices_vector.begin(), after_dim_indices_vector.end(), dim + 1);
+
+                   nvinfer1::ITensor* after_dim_gather_out = nullptr;
+                   if (after_dim_indices_vector.size()) {
+                     at::Tensor after_dim_indices = torch::tensor(after_dim_indices_vector).to(torch::kI32);
+                     auto after_dim_indices_out = converters::tensor_to_const(ctx, after_dim_indices);
+                     auto after_dim_gather_layer = ctx->net->addGather(*shape_1d_tensor, *after_dim_indices_out, 0);
+                     TORCHTRT_CHECK(after_dim_gather_layer, "Unable to create gather layer from node: " << *n);
+                     after_dim_gather_out = after_dim_gather_layer->getOutput(0);
+                   }
+
+                   // Concatenate [dims before dim] + [unflattened sizes] + [dims after dim].
+                   std::vector<nvinfer1::ITensor*> shape_tensors;
+                   if (before_dim_gather_out) {
+                     shape_tensors.push_back(before_dim_gather_out);
+                   }
+                   for (auto new_shape_tensor : expand_shape) {
+                     shape_tensors.push_back(new_shape_tensor);
+                   }
+                   if (after_dim_gather_out) {
+                     shape_tensors.push_back(after_dim_gather_out);
+                   }
+
+                   auto shape_cat_layer = ctx->net->addConcatenation(shape_tensors.data(), shape_tensors.size());
+                   TORCHTRT_CHECK(shape_cat_layer, "Unable to create cat layer from node: " << *n);
+                   shape_tensor = shape_cat_layer->getOutput(0);
+                   LOG_DEBUG("Shape tensor shape: " << shape_tensor->getDimensions());
+                 } else if (args[2].isIntList()) {
+                   auto shape_vec = args[2].unwrapToIntList().vec();
+                   // New shape: dims before dim + unflattened sizes + dims after dim.
+                   new_shape.insert(new_shape.end(), in_shape.begin(), in_shape.begin() + dim);
+                   new_shape.insert(new_shape.end(), shape_vec.begin(), shape_vec.end());
+                   new_shape.insert(new_shape.end(), in_shape.begin() + dim + 1, in_shape.end());
+
+                   shape_tensor = tensor_to_const(ctx, torch::tensor(new_shape).to(torch::kI32));
+                 } else {
+                   LOG_ERROR(
+                       "Invalid IValue type of " << args[2].ivalue_type()
+                                                 << " detected for shape tensor from node: " << *n);
+                 }
+               } else {
+                 // Static shapes: let torch::unflatten compute the target shape directly.
+                 new_shape = torch::unflatten(torch::rand(in_shape), dim, args[2].unwrapToIntList().vec()).sizes().vec();
+               }
+               auto shuffle = ctx->net->addShuffle(*in);
+               TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
+               shuffle->setName(util::node_info(n).c_str());
+
+               if (ctx->input_is_dynamic) {
+                 shuffle->setInput(1, *shape_tensor);
+               } else {
+                 shuffle->setReshapeDimensions(util::toDims(new_shape));
+               }
+
+               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
+               LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+
+               return true;
+             }})
        .pattern(
            {"aten::reshape(Tensor self, int[] shape) -> (Tensor)",
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
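For reference, the static-shape branch above delegates the target-shape computation to torch::unflatten, the same call shown in the diff. A minimal standalone libtorch sketch of those semantics (the program itself is illustrative and assumes only a local libtorch install):

#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::rand({1, 512});
  // Split dim 1 (size 512) into sizes {512, 1, 1}; their product must equal 512.
  auto y = torch::unflatten(x, 1, {512, 1, 1});
  std::cout << y.sizes() << std::endl; // prints [1, 512, 1, 1]
  return 0;
}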

tests/core/conversion/converters/test_shuffle.cpp

Lines changed: 52 additions & 0 deletions
@@ -364,3 +364,55 @@ TEST(Converters, ATenPixelShuffle5DConvertsCorrectly) {

   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
+
+TEST(Converters, ATenUnflattenConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=1]()
+        %3 : int = prim::Constant[value=512]()
+        %4 : int = prim::Constant[value=1]()
+        %5 : int = prim::Constant[value=1]()
+        %6 : int[] = prim::ListConstruct(%3, %4, %5)
+        %7 : Tensor = aten::unflatten(%x.1, %2, %6)
+        return (%7))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(0, 5, {1, 512}, {at::kCUDA});
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+  in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenUnflattenNegativeDimConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=-1]()
+        %3 : int = prim::Constant[value=512]()
+        %4 : int = prim::Constant[value=1]()
+        %5 : int = prim::Constant[value=1]()
+        %6 : int[] = prim::ListConstruct(%3, %4, %5)
+        %7 : Tensor = aten::unflatten(%x.1, %2, %6)
+        return (%7))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(0, 5, {1, 512}, {at::kCUDA});
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+  in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
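These two tests differ only in the dim constant: with a rank-2 input, dim = -1 resolves to the same axis as dim = 1. In the static path that normalization happens inside torch::unflatten itself; the converter's dynamic branch performs it explicitly as dim = in_shape.size() + dim. A tiny illustrative sketch of that rule (the helper name is hypothetical, not part of the codebase):

#include <cassert>
#include <cstdint>

// Hypothetical helper mirroring the dynamic branch's negative-dim handling.
int64_t normalize_dim(int64_t dim, int64_t rank) {
  return dim < 0 ? rank + dim : dim;
}

int main() {
  assert(normalize_dim(-1, 2) == 1); // ATenUnflattenNegativeDimConvertsCorrectly
  assert(normalize_dim(1, 2) == 1);  // ATenUnflattenConvertsCorrectly
  return 0;
}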

tests/cpp/test_dynamic_size.cpp

Lines changed: 134 additions & 1 deletion
@@ -124,4 +124,137 @@ TEST(Converters, ATenResizeGetItemDynShapeMulCorrectly) {
   auto trt = trt_results[0].reshape(jit_results[0].sizes());

   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
-}
+}
+
+TEST(Converters, ATenUnflattenDynShapeShapeCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=1]()
+        %3 : int = prim::Constant[value=512]()
+        %4 : int = prim::Constant[value=1]()
+        %5 : int = prim::Constant[value=1]()
+        %6 : int[] = prim::ListConstruct(%3, %4, %5)
+        %7 : Tensor = aten::unflatten(%x.1, %2, %6)
+        return (%7))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(0, 10, {1, 512}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenUnflattenDynShapeNegativeDimsShapeCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=-2]()
+        %3 : int = prim::Constant[value=512]()
+        %4 : int = prim::Constant[value=1]()
+        %5 : int = prim::Constant[value=1]()
+        %6 : int[] = prim::ListConstruct(%3, %4, %5)
+        %7 : Tensor = aten::unflatten(%x.1, %2, %6)
+        return (%7))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(0, 10, {1, 512, 2}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenUnflattenDynShapeITensorShapeCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=1]()
+        %3 : int = aten::size(%x.1, %2)
+        %4 : int = prim::Constant[value=256]()
+        %5 : int = prim::Constant[value=2]()
+        %6 : int[] = prim::ListConstruct(%4, %5)
+        %7 : Tensor = aten::unflatten(%x.1, %2, %6)
+        return (%7))IR";
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(0, 10, {1, 512, 1}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenUnflattenDynShapeITensorShapeCorrectlyFirstDim) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %1 : int = prim::Constant[value=0]()
+        %2 : int = prim::Constant[value=1]()
+        %3 : int = aten::size(%x.1, %1)
+        %6 : int[] = prim::ListConstruct(%2, %2, %3, %2, %2)
+        %7 : Tensor = aten::unflatten(%x.1, %1, %6)
+        return (%7))IR";
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(0, 10, {64, 512, 1}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenUnflattenDynShapeITensorShapeCorrectlyLastDim) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %1 : int = prim::Constant[value=2]()
+        %2 : int = prim::Constant[value=1]()
+        %3 : int = aten::size(%x.1, %1)
+        %5 : int = prim::Constant[value=2]()
+        %6 : int[] = prim::ListConstruct(%3, %2, %2)
+        %7 : Tensor = aten::unflatten(%x.1, %5, %6)
+        return (%7))IR";
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(0, 10, {1, 512, 9}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
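As a sanity check on what these dynamic-shape tests expect, the shape the converter concatenates is always [dims of in_shape before dim] + sizes + [dims of in_shape after dim]. A standalone sketch of that bookkeeping with plain std::vector, no TensorRT (the helper name is hypothetical, purely illustrative):

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper mirroring the converter's shape assembly:
// dims before dim, then the unflattened sizes, then dims after dim.
std::vector<int64_t> unflatten_shape(std::vector<int64_t> in_shape, int64_t dim, std::vector<int64_t> sizes) {
  if (dim < 0) {
    dim += static_cast<int64_t>(in_shape.size());
  }
  std::vector<int64_t> out(in_shape.begin(), in_shape.begin() + dim);
  out.insert(out.end(), sizes.begin(), sizes.end());
  out.insert(out.end(), in_shape.begin() + dim + 1, in_shape.end());
  return out;
}

int main() {
  // Mirrors ATenUnflattenDynShapeITensorShapeCorrectly: {1, 512, 1}, dim 1, sizes {256, 2}.
  for (auto d : unflatten_shape({1, 512, 1}, 1, {256, 2})) {
    std::cout << d << " "; // prints: 1 256 2 1
  }
  std::cout << std::endl;
  return 0;
}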
