Commit 2d8ab5c

Add expand layer, expand_as and repeat layer functionality
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent b787c5e commit 2d8ab5c

4 files changed, +360 −0 lines changed
core/conversion/converters/BUILD

Lines changed: 1 addition & 0 deletions

@@ -39,6 +39,7 @@ cc_library(
         "impl/constant.cpp",
         "impl/conv_deconv.cpp",
         "impl/element_wise.cpp",
+        "impl/expand.cpp",
         "impl/linear.cpp",
         "impl/matrix_multiply.cpp",
         "impl/pooling.cpp",
core/conversion/converters/impl/expand.cpp

Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
#include "NvInfer.h"
#include "core/conversion/converters/converters.h"
#include "core/conversion/tensorcontainer/TensorContainer.h"
#include "core/util/prelude.h"
#include "core/util/trt_util.h"
#include "torch/torch.h"

#include <ATen/ATen.h>
#include <vector>

namespace trtorch {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, nvinfer1::Dims expandedDims) {
  auto input_dims = in->getDimensions();
  TRTORCH_CHECK(
      input_dims.nbDims <= expandedDims.nbDims,
      "Number of dimensions of the desired expansion must be greater than or equal to the number of input dimensions");

  // Validate the expansion. Eg: an input of [3, 1] can be expanded to [1, 3, 4] but not [3, 4, 1]
  for (int i = expandedDims.nbDims - 1; i >= 0; --i) {
    int64_t offset = expandedDims.nbDims - 1 - i;
    int64_t dim = input_dims.nbDims - 1 - offset;
    int64_t size = (dim >= 0) ? input_dims.d[dim] : 1;
    int64_t targetSize = expandedDims.d[i];
    if (size != targetSize) {
      if (size != 1) {
        TRTORCH_THROW_ERROR(
            "The expanded size of tensor (" << targetSize << ")"
            << " must match the existing size (" << size << ")"
            << " at dimension " << i);
      }
    }
  }

  auto num_expand_dims = expandedDims.nbDims - input_dims.nbDims;
  if (num_expand_dims > 0) {
    nvinfer1::Dims reshape_dims;
    reshape_dims.nbDims = expandedDims.nbDims;
    for (int i = 0; i < num_expand_dims; i++) {
      reshape_dims.d[i] = 1;
    }
    for (int i = 0; i < input_dims.nbDims; i++) {
      reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
    }
    // Add a reshape layer to expand dims
    auto reshape_layer = ctx->net->addShuffle(*in);
    reshape_layer->setReshapeDimensions(reshape_dims);
    in = reshape_layer->getOutput(0);
    LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
  }

  // Start the slicing from beginning of tensor since this is an expand layer
  std::vector<int64_t> start_vec(expandedDims.nbDims, 0);
  auto start_offset = util::toDims(c10::IntArrayRef(start_vec));

  // Set the stride of non singleton dimension to 1
  std::vector<int64_t> strides_vec(expandedDims.nbDims, 0);
  for (int i = 0; i < expandedDims.nbDims; i++) {
    strides_vec[i] = (in->getDimensions().d[i] != 1);
  }

  auto strides = util::toDims(c10::IntArrayRef(strides_vec));
  // Slice layer does the expansion in TRT. Desired output size is specified by expandedDims
  auto slice_layer = ctx->net->addSlice(*in, start_offset, expandedDims, strides);
  slice_layer->setName(util::node_info(n).c_str());

  auto out = ctx->AssociateValueAndTensor(n->outputs()[0], slice_layer->getOutput(0));

  LOG_DEBUG("Expand layer output tensor shape: " << out->getDimensions());

  return true;
}

auto expand_registrations TRTORCH_UNUSED =
    RegisterNodeConversionPatterns()
        .pattern({"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))",
                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                    auto in = args[0].ITensor();
                    auto input_dims = in->getDimensions();
                    auto expanded_size = args[1].unwrapToIntList();
                    auto expandedDims = util::toDims(expanded_size);
                    LOG_DEBUG("(expand layer) Expand input from " << input_dims << " to " << expandedDims);
                    return add_expand(ctx, n, in, expandedDims);
                  }})
        .pattern({"aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))",
                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                    // TODO: Currently expand supports static shapes. Need to explore if the same code can be extended
                    // to dynamic expansion.
                    auto in = args[0].ITensor();
                    auto input_dims = in->getDimensions();
                    auto targetTensor = args[1].ITensor();
                    auto targetDims = targetTensor->getDimensions();
                    LOG_DEBUG("(expand_as layer) Expand input from " << input_dims << " to " << targetDims);
                    return add_expand(ctx, n, in, targetDims);
                  }})
        .pattern({"aten::repeat(Tensor self, int[] repeats) -> (Tensor)",
                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                    auto in = args[0].ITensor();
                    auto input_dims = in->getDimensions();
                    auto repeats = args[1].unwrapToIntList().vec();
                    TRTORCH_CHECK(
                        repeats.size() >= input_dims.nbDims,
                        "Number of repeat dimensions cannot be smaller than number of input dimensions");
                    auto num_expand_dims = repeats.size() - input_dims.nbDims;
                    if (num_expand_dims > 0) {
                      nvinfer1::Dims reshape_dims;
                      reshape_dims.nbDims = repeats.size();
                      for (int i = 0; i < num_expand_dims; i++) {
                        reshape_dims.d[i] = 1;
                      }
                      for (int i = 0; i < input_dims.nbDims; i++) {
                        reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
                      }
                      // Add a reshape layer to expand dims
                      auto reshape_layer = ctx->net->addShuffle(*in);
                      reshape_layer->setReshapeDimensions(reshape_dims);
                      in = reshape_layer->getOutput(0);
                      LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
                    }

                    LOG_DEBUG("Repeats: " << repeats);

                    // Concat across all repeat axes.
                    // TODO: Implementation might not be performant. Explore other strategies to improve performance.
                    for (int i = repeats.size() - 1; i >= 0; --i) {
                      std::vector<nvinfer1::ITensor*> tensors_vec;
                      for (int j = 0; j < repeats[i]; j++) {
                        tensors_vec.push_back(in);
                      }
                      auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
                      concat_layer->setAxis(i);
                      in = concat_layer->getOutput(0);
                    }

                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);

                    LOG_DEBUG("Repeat layer output tensor shape: " << in->getDimensions());

                    return true;
                  }});

} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch
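
Note: the converter above realises aten::expand by left-padding the input shape with 1s and then running a TensorRT slice whose stride is 0 on broadcast axes, so the same element is re-read instead of copied. A minimal standalone sketch of that shape/stride arithmetic (plain C++, no TensorRT; the example shapes {3, 1} → {2, 3, 4} are illustrative, not taken from this commit):

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the arithmetic used by add_expand: left-pad the input shape to
// the target rank, then give singleton axes a stride of 0 so a slice over
// the target shape replicates them.
int main() {
  std::vector<int64_t> input_dims = {3, 1};        // hypothetical input shape
  std::vector<int64_t> expanded_dims = {2, 3, 4};  // hypothetical requested size

  size_t pad = expanded_dims.size() - input_dims.size();
  std::vector<int64_t> reshape_dims(expanded_dims.size(), 1);
  for (size_t i = 0; i < input_dims.size(); i++) {
    reshape_dims[pad + i] = input_dims[i];  // -> {1, 3, 1}
  }

  std::vector<int64_t> start(expanded_dims.size(), 0);
  std::vector<int64_t> strides(expanded_dims.size());
  for (size_t i = 0; i < expanded_dims.size(); i++) {
    strides[i] = (reshape_dims[i] != 1);  // 0 on broadcast axes, 1 otherwise
  }

  // In the converter these vectors become the start, size and stride
  // arguments of the slice layer; here we just print them.
  for (const auto& v : {reshape_dims, start, strides}) {
    for (auto d : v) std::cout << d << " ";
    std::cout << "\n";
  }
}

With these values a slice over the reshaped input produces the expanded tensor without materialising copies in the network definition, which is why the converter only needs a shuffle plus a slice layer.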

tests/core/conversion/converters/BUILD

Lines changed: 5 additions & 0 deletions

@@ -27,6 +27,10 @@ converter_test(
     name = "test_element_wise"
 )

+converter_test(
+    name = "test_expand"
+)
+
 converter_test(
     name = "test_linear"
 )

@@ -78,6 +82,7 @@ test_suite(
         ":test_batch_norm",
         ":test_conv_deconv",
         ":test_element_wise",
+        ":test_expand",
         ":test_linear",
         ":test_matrix_multiply",
         ":test_pooling",
tests/core/conversion/converters/test_expand.cpp

Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
#include <torch/torch.h>
#include <string>
#include "core/compiler.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"

TEST(Converters, ATenExpandSameDimConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %2 : int[] = prim::Constant[value=[3, 4]]()
      %3 : bool = prim::Constant[value=0]()
      %4 : Tensor = aten::expand(%x.1, %2, %3)
      return (%4))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, &*g);

  auto in = at::randint(1, 10, {3, 1}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenExpandTileConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %2 : int[] = prim::Constant[value=[2, 3, 1]]()
      %3 : bool = prim::Constant[value=0]()
      %4 : Tensor = aten::expand(%x.1, %2, %3)
      return (%4))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, &*g);

  auto in = at::randint(1, 10, {3, 1}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenExpandTileLastConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %2 : int[] = prim::Constant[value=[1, 3, 4]]()
      %3 : bool = prim::Constant[value=0]()
      %4 : Tensor = aten::expand(%x.1, %2, %3)
      return (%4))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, &*g);

  auto in = at::randint(1, 10, {3, 1}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

/* Expand_as layer takes two inputs and only dimensions of second input are
   actually used. TRT prunes away the second input. This will result in internal
   failure from TRT. To avoid unrelated issues, we add a dummy operation which
   outputs second_input+2 as a second output. The second input is preserved.
*/
TEST(Converters, ATenExpandASConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor,
          %y.1 : Tensor):
      %3 : int = prim::Constant[value=1]()
      %4 : int = prim::Constant[value=2]()
      %5 : Tensor = aten::expand_as(%x.1, %y.1)
      %6 : Tensor = aten::add(%y.1, %4, %3)
      return (%5, %6))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, &*g);

  auto in = at::randint(1, 10, {3, 1}, {at::kCUDA});
  auto target_in = at::randint(1, 10, {2, 3, 1}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto jit_target_in = at::clone(target_in);
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in, jit_target_in});

  auto trt_in = at::clone(jit_in);
  auto trt_target_in = at::clone(jit_target_in);
  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in, trt_target_in});

  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRepeatConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %2 : int[] = prim::Constant[value=[4, 2]]()
      %3 : Tensor = aten::repeat(%x.1, %2)
      return (%3))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, &*g);

  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(jit_in);
  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRepeat3dConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %2 : int[] = prim::Constant[value=[2, 2, 2]]()
      %3 : Tensor = aten::repeat(%x.1, %2)
      return (%3))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, &*g);

  auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(jit_in);
  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRepeatExtraDimsConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %2 : int[] = prim::Constant[value=[1, 3, 2]]()
      %3 : Tensor = aten::repeat(%x.1, %2)
      return (%3))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, &*g);

  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(jit_in);
  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}
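
Note: the tests above compare the TensorRT engine output against the TorchScript interpreter. As a reminder of the eager semantics being reproduced, a minimal libtorch sketch (shapes chosen to mirror the tests; this snippet is not part of the commit):

#include <torch/torch.h>
#include <iostream>

int main() {
  // Mirrors the shapes used in the tests above.
  auto x = torch::randint(1, 10, {3, 1});

  // aten::expand broadcasts singleton dimensions to the requested size
  // (and may add new leading dimensions) without copying data.
  auto expanded = x.expand({2, 3, 4});   // -> shape [2, 3, 4]

  // aten::repeat tiles the tensor along each axis; extra leading repeat
  // values add new dimensions first.
  auto y = torch::randint(1, 10, {1, 3});
  auto repeated = y.repeat({4, 2});      // -> shape [4, 6]

  std::cout << expanded.sizes() << " " << repeated.sizes() << std::endl;
}

Because expand only broadcasts while repeat actually tiles data, the converter lowers expand/expand_as to a single slice layer but lowers repeat to a reshape followed by a chain of concatenations, as seen in expand.cpp above.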
