Skip to content

Commit 05652b8

Browse files
authored
Merge pull request #301 from NVIDIA/squeeze_unsqueeze
Add support for aten::squeeze, aten::unsqueeze
2 parents b433a53 + 4d78e51 commit 05652b8

File tree

8 files changed

+183
-2
lines changed

8 files changed

+183
-2
lines changed

core/conversion/converters/BUILD

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,10 @@ cc_library(
4848
"impl/unary.cpp",
4949
"impl/interpolate.cpp",
5050
"impl/select.cpp",
51+
"impl/squeeze.cpp",
5152
"impl/stack.cpp",
52-
"impl/lstm_cell.cpp"
53+
"impl/lstm_cell.cpp",
54+
"impl/unsqueeze.cpp",
5355
],
5456
deps = [
5557
"@tensorrt//:nvinfer",
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#include "NvInfer.h"
2+
#include "core/conversion/converters/converters.h"
3+
#include "core/conversion/tensorcontainer/TensorContainer.h"
4+
#include "core/util/prelude.h"
5+
#include "torch/torch.h"
6+
7+
#include <ATen/ATen.h>
8+
#include <vector>
9+
10+
namespace trtorch {
11+
namespace core {
12+
namespace conversion {
13+
namespace converters {
14+
namespace impl {
15+
namespace {
16+
17+
auto squeeze_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
18+
{"aten::squeeze.dim(Tensor(a) self, int dim) -> (Tensor(a))",
19+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
20+
auto self = args[0].ITensorOrFreeze(ctx);
21+
auto dim = args[1].unwrapToInt();
22+
23+
auto selfDim = util::toVec(self->getDimensions());
24+
if (dim < 0) {
25+
dim = selfDim.size() + dim;
26+
}
27+
28+
auto shuffle_layer = ctx->net->addShuffle(*self);
29+
TRTORCH_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
30+
shuffle_layer->setReshapeDimensions(util::squeezeDims(self->getDimensions(), dim));
31+
32+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_layer->getOutput(0));
33+
34+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
35+
36+
return true;
37+
}});
38+
39+
} // namespace
40+
} // namespace impl
41+
} // namespace converters
42+
} // namespace conversion
43+
} // namespace core
44+
} // namespace trtorch
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#include "NvInfer.h"
2+
#include "core/conversion/converters/converters.h"
3+
#include "core/conversion/tensorcontainer/TensorContainer.h"
4+
#include "core/util/prelude.h"
5+
#include "torch/torch.h"
6+
7+
#include <ATen/ATen.h>
8+
#include <vector>
9+
10+
namespace trtorch {
11+
namespace core {
12+
namespace conversion {
13+
namespace converters {
14+
namespace impl {
15+
namespace {
16+
17+
auto unsqueeze_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
18+
{"aten::unsqueeze(Tensor(a) self, int dim) -> (Tensor(a))",
19+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
20+
auto self = args[0].ITensorOrFreeze(ctx);
21+
auto dim = args[1].unwrapToInt();
22+
23+
auto selfDim = util::toVec(self->getDimensions());
24+
if (dim < 0) {
25+
dim = selfDim.size() + dim;
26+
}
27+
28+
auto shuffle_layer = ctx->net->addShuffle(*self);
29+
TRTORCH_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
30+
shuffle_layer->setReshapeDimensions(util::unsqueezeDims(self->getDimensions(), dim));
31+
32+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_layer->getOutput(0));
33+
34+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
35+
36+
return true;
37+
}});
38+
39+
} // namespace
40+
} // namespace impl
41+
} // namespace converters
42+
} // namespace conversion
43+
} // namespace core
44+
} // namespace trtorch

core/util/trt_util.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,34 @@ nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos) {
186186
return dims;
187187
}
188188

189+
nvinfer1::Dims squeezeDims(const nvinfer1::Dims& d, int pos) {
190+
// acceptable range for pos is [0, d.nbDims]
191+
TRTORCH_ASSERT(pos >= 0 && pos <= d.nbDims, "ERROR: Index to squeeze is out of bounds.");
192+
193+
nvinfer1::Dims dims;
194+
195+
int i = 0;
196+
int j = 0;
197+
198+
while (i <= d.nbDims) {
199+
if (j != pos) {
200+
dims.d[j] = d.d[i];
201+
} else {
202+
// add new dimension at pos
203+
i++;
204+
if (i <= d.nbDims) {
205+
dims.d[j] = d.d[i];
206+
}
207+
}
208+
i++;
209+
j++;
210+
}
211+
212+
dims.nbDims = d.nbDims - 1;
213+
214+
return dims;
215+
}
216+
189217
std::vector<int64_t> toVec(nvinfer1::Dims d) {
190218
std::vector<int64_t> dims;
191219
for (int i = 0; i < d.nbDims; i++) {

core/util/trt_util.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to);
9494
nvinfer1::Dims toDimsPad(c10::List<int64_t> l, uint64_t pad_to);
9595
nvinfer1::Dims unpadDims(const nvinfer1::Dims& d);
9696
nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos);
97+
nvinfer1::Dims squeezeDims(const nvinfer1::Dims& d, int pos);
9798
nvinfer1::Dims toDims(c10::IntArrayRef l);
9899
nvinfer1::Dims toDims(c10::List<int64_t> l);
99100
nvinfer1::DimsHW toDimsHW(c10::List<int64_t> l);

tests/core/conversion/converters/BUILD

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,14 @@ converter_test(
7171
name = "test_lstm_cell"
7272
)
7373

74+
converter_test(
75+
name = "test_unsqueeze"
76+
)
77+
78+
converter_test(
79+
name = "test_squeeze"
80+
)
81+
7482
test_suite(
7583
name = "converter_tests",
7684
tests = [
@@ -88,6 +96,8 @@ test_suite(
8896
":test_interpolate",
8997
":test_select",
9098
":test_stack",
91-
":test_lstm_cell"
99+
":test_lstm_cell",
100+
":test_unsqueeze",
101+
":test_squeeze"
92102
]
93103
)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "gtest/gtest.h"
4+
#include "tests/util/util.h"
5+
#include "torch/csrc/jit/ir/irparser.h"
6+
7+
TEST(Converters, ATenSqueezeConvertsCorrectly) {
8+
const auto graph = R"IR(
9+
graph(%0 : Tensor):
10+
%1 : int = prim::Constant[value=1]()
11+
%2 : Tensor = aten::squeeze(%0, %1)
12+
return (%2))IR";
13+
14+
auto g = std::make_shared<torch::jit::Graph>();
15+
torch::jit::parseIR(graph, &*g);
16+
17+
auto in = at::randint(1, 10, {2, 1, 3, 3}, {at::kCUDA});
18+
19+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
20+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
21+
22+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
23+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
24+
25+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
26+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "gtest/gtest.h"
4+
#include "tests/util/util.h"
5+
#include "torch/csrc/jit/ir/irparser.h"
6+
7+
TEST(Converters, ATenUnsqueezeConvertsCorrectly) {
8+
const auto graph = R"IR(
9+
graph(%0 : Tensor):
10+
%1 : int = prim::Constant[value=2]()
11+
%2 : Tensor = aten::unsqueeze(%0, %1)
12+
return (%2))IR";
13+
14+
auto g = std::make_shared<torch::jit::Graph>();
15+
torch::jit::parseIR(graph, &*g);
16+
17+
auto in = at::randint(1, 10, {2, 3, 3}, {at::kCUDA});
18+
19+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
20+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
21+
22+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
23+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
24+
25+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
26+
}

0 commit comments

Comments
 (0)