
Commit aa131ac

Merge pull request #87 from abhi-iyer/master
Support for interpolation (aten::upsample_nearest)
2 parents d6c8d31 + 5ddab8b commit aa131ac
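
The converters added here cover aten::upsample_nearest1d/2d/3d, which is what nearest-neighbor interpolation lowers to in TorchScript. Below is a minimal, hedged C++ (LibTorch) sketch of the kind of call that produces one of these ops; the tensor shapes and values are illustrative and not taken from this PR:

#include <torch/torch.h>
#include <iostream>
#include <vector>

int main() {
    namespace F = torch::nn::functional;

    // 4D input (N, C, H, W); nearest-neighbor resize to a 10x10 spatial size.
    auto x = torch::randn({1, 3, 4, 4});
    auto y = F::interpolate(
        x,
        F::InterpolateFuncOptions().size(std::vector<int64_t>{10, 10}).mode(torch::kNearest));

    // The TorchScript form of this call is
    //   aten::upsample_nearest2d(Tensor, int[2] output_size, float? scales_h, float? scales_w),
    // one of the three schemas registered by the new converter in this commit.
    std::cout << y.sizes() << std::endl;  // [1, 3, 10, 10]
    return 0;
}

Note that only the output_size form is handled; scale-factor-based calls are rejected by the new converters for now, as the error messages in the added file indicate.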

File tree

5 files changed: +273 additions, −4 deletions

    core/conversion/InterfaceTypes.cpp
    core/conversion/converters/BUILD
    core/conversion/converters/impl/interpolate.cpp (new)
    tests/core/converters/BUILD
    tests/core/converters/test_interpolate.cpp (new)

core/conversion/InterfaceTypes.cpp

Lines changed: 4 additions & 4 deletions
@@ -55,9 +55,9 @@ InputRange::InputRange(std::vector<int64_t> min_shape, std::vector<int64_t> opt_
                 << max_shape.size() << ")");
     }
 
-    min = util::toDimsPad(min_shape, 4);
-    opt = util::toDimsPad(opt_shape, 4);
-    max = util::toDimsPad(max_shape, 4);
+    min = util::toDims(min_shape);
+    opt = util::toDims(opt_shape);
+    max = util::toDims(max_shape);
 
     std::vector<int64_t> dyn_shape;
     for (size_t i = 0; i < opt_shape.size(); i++) {
@@ -73,7 +73,7 @@ InputRange::InputRange(std::vector<int64_t> min_shape, std::vector<int64_t> opt_
         }
     }
 
-    input_shape = util::toDimsPad(dyn_shape, 4);
+    input_shape = util::toDims(dyn_shape);
 
 }
 
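
The switch from util::toDimsPad(..., 4) to util::toDims(...) stops forcing every input shape to rank 4, which matters here because upsample_nearest1d expects 3D inputs and upsample_nearest3d expects 5D inputs (see the comments in the new tests). A minimal sketch of the assumed difference, using hypothetical stand-ins rather than the real util helpers:

#include <NvInfer.h>
#include <cstdint>
#include <vector>

// Illustrative stand-in for util::toDims: keep the rank of the shape as-is,
// so 3D (upsample_nearest1d) and 5D (upsample_nearest3d) inputs survive.
nvinfer1::Dims toDimsSketch(const std::vector<int64_t>& shape) {
    nvinfer1::Dims dims;
    dims.nbDims = static_cast<int>(shape.size());
    for (size_t i = 0; i < shape.size(); i++) {
        dims.d[i] = static_cast<int>(shape[i]);
    }
    return dims;
}

// Illustrative stand-in for util::toDimsPad(shape, 4), assumed to left-pad
// with 1s up to 4 dimensions: a 3D input becomes 4D, so the original rank
// needed by the 1D/3D upsample converters is lost.
nvinfer1::Dims toDimsPadSketch(const std::vector<int64_t>& shape, size_t pad_to) {
    nvinfer1::Dims dims;
    size_t rank = shape.size() < pad_to ? pad_to : shape.size();
    size_t offset = rank - shape.size();
    dims.nbDims = static_cast<int>(rank);
    for (size_t i = 0; i < rank; i++) {
        dims.d[i] = (i < offset) ? 1 : static_cast<int>(shape[i - offset]);
    }
    return dims;
}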

core/conversion/converters/BUILD

File mode changed from 100644 to 100755
Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ cc_library(
         "impl/shuffle.cpp",
         "impl/softmax.cpp",
         "impl/unary.cpp",
+        "impl/interpolate.cpp"
     ],
     deps = [
         "@tensorrt//:nvinfer",
core/conversion/converters/impl/interpolate.cpp (new file)

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
#include "torch/torch.h"
#include "core/util/prelude.h"
#include "core/conversion/converters/converters.h"

#include <csignal>

namespace trtorch {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
    .pattern({
        "aten::upsample_nearest1d(Tensor self, int[1] output_size, float? scales=None) -> (Tensor)",
        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            auto in = args[0].ITensor();
            auto in_shape = util::toVec(in->getDimensions());

            // Case 1: user uses output size and not scales
            if (!args[1].IValue()->isNone() && args[2].IValue()->isNone()) {
                auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));

                TRTORCH_ASSERT(out_size.size() == 1, "aten::upsample_nearest1d input Tensor and output size dimension mismatch");

                auto out_shape = in_shape;
                std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));

                auto resize_layer = ctx->net->addResize(*in);
                TRTORCH_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n);

                resize_layer->setOutputDimensions(util::toDims(out_shape));
                resize_layer->setResizeMode(nvinfer1::ResizeMode::kNEAREST);
                resize_layer->setName(util::node_info(n).c_str());

                auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
                LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
            } else {
                TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_nearest1d not supported yet.");
            }

            return true;
        }
    }).pattern({
        "aten::upsample_nearest2d(Tensor self, int[2] output_size, float? scales_h=None, float? scales_w=None) -> (Tensor)",
        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            auto in = args[0].ITensor();
            auto in_shape = util::toVec(in->getDimensions());

            // Case 1: user uses output_size and not scales_h, scales_w
            if (!args[1].IValue()->isNone() && args[2].IValue()->isNone() && args[3].IValue()->isNone()) {
                auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));

                TRTORCH_ASSERT(out_size.size() == 2, "aten::upsample_nearest2d input Tensor and output size dimension mismatch");

                auto out_shape = in_shape;
                std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));

                auto resize_layer = ctx->net->addResize(*in);
                TRTORCH_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n);

                resize_layer->setOutputDimensions(util::toDims(out_shape));
                resize_layer->setResizeMode(nvinfer1::ResizeMode::kNEAREST);
                resize_layer->setName(util::node_info(n).c_str());

                auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
                LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
            } else {
                TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_nearest2d not supported yet.");
            }

            return true;
        }
    }).pattern({
        "aten::upsample_nearest3d(Tensor self, int[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> (Tensor)",
        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            auto in = args[0].ITensor();
            auto in_shape = util::toVec(in->getDimensions());

            // Case 1: user uses output size and not scales_d, scales_h, scales_w
            if (!args[1].IValue()->isNone() && args[2].IValue()->isNone() && args[3].IValue()->isNone() && args[4].IValue()->isNone()) {
                auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));

                TRTORCH_ASSERT(out_size.size() == 3, "aten::upsample_nearest3d input Tensor and output size dimension mismatch");

                auto out_shape = in_shape;
                std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));

                auto resize_layer = ctx->net->addResize(*in);
                TRTORCH_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n);

                resize_layer->setOutputDimensions(util::toDims(out_shape));
                resize_layer->setResizeMode(nvinfer1::ResizeMode::kNEAREST);
                resize_layer->setName(util::node_info(n).c_str());

                auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
                LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
            } else {
                TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_nearest3d not supported yet.");
            }

            return true;
        }
    });


} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch
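
A note on the shape handling shared by all three converters above: output_size describes only the spatial dimensions, so it is copied over the tail of the input shape while the leading batch/channel dimensions are kept. A small self-contained example of that std::copy, with illustrative shapes:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    // e.g. a 4D input to upsample_nearest2d: (N, C, H, W)
    std::vector<int64_t> in_shape = {10, 2, 2, 2};
    // requested output_size covers only the spatial dims: (H_out, W_out)
    std::vector<int64_t> out_size = {5, 6};

    // Same logic as the converters: overwrite the trailing
    // out_size.size() entries of the input shape.
    auto out_shape = in_shape;
    std::copy(out_size.begin(), out_size.end(),
              out_shape.begin() + (in_shape.size() - out_size.size()));

    assert((out_shape == std::vector<int64_t>{10, 2, 5, 6}));
    return 0;
}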

tests/core/converters/BUILD

Lines changed: 5 additions & 0 deletions
@@ -55,6 +55,10 @@ converter_test(
     name = "test_unary"
 )
 
+converter_test(
+    name = "test_interpolate"
+)
+
 test_suite(
     name = "test_converters",
     tests = [
@@ -69,6 +73,7 @@ test_suite(
         ":test_shuffle",
         ":test_softmax",
         ":test_unary",
+        ":test_interpolate",
     ]
 )
 
tests/core/converters/test_interpolate.cpp (new file)

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
#include <string>
#include "gtest/gtest.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "tests/util/util.h"
#include "core/compiler.h"

TEST(Converters, ATenUpsampleNearest1dConvertsCorrectly) {
    const auto graph = R"IR(
        graph(%0 : Tensor):
            %2 : int = prim::Constant[value=10]()
            %3 : int[] = prim::ListConstruct(%2)
            %4 : None = prim::Constant()
            %5 : Tensor = aten::upsample_nearest1d(%0, %3, %4)
            return (%5))IR";

    auto g = std::make_shared<torch::jit::Graph>();

    torch::jit::parseIR(graph, &*g);

    // Input Tensor needs to be 3D for TensorRT upsample_nearest1d
    auto in = at::randint(1, 10, {10, 2, 2}, {at::kCUDA});

    auto jit_in = at::clone(in);
    auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
    auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

    auto trt_in = at::clone(in);
    params = trtorch::core::conversion::get_named_params(g->inputs(), {});
    auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

    auto trt = trt_results[0].reshape(jit_results[0].sizes());

    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenUpsampleNearest2dConvertsCorrectly1dOutputSize) {
    const auto graph = R"IR(
        graph(%0 : Tensor):
            %2 : int = prim::Constant[value=10]()
            %3 : int[] = prim::ListConstruct(%2, %2)
            %4 : None = prim::Constant()
            %5 : Tensor = aten::upsample_nearest2d(%0, %3, %4, %4)
            return (%5))IR";

    auto g = std::make_shared<torch::jit::Graph>();

    torch::jit::parseIR(graph, &*g);

    // Input Tensor needs to be 4D for TensorRT upsample_nearest2d
    auto in = at::randint(1, 10, {10, 2, 2, 2}, {at::kCUDA});

    auto jit_in = at::clone(in);
    auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
    auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

    auto trt_in = at::clone(in);
    params = trtorch::core::conversion::get_named_params(g->inputs(), {});
    auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

    auto trt = trt_results[0].reshape(jit_results[0].sizes());

    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenUpsampleNearest2dConvertsCorrectly2dOutputSize) {
    const auto graph = R"IR(
        graph(%0 : Tensor):
            %2 : int = prim::Constant[value=10]()
            %3 : int[] = prim::ListConstruct(%2, %2)
            %4 : None = prim::Constant()
            %5 : Tensor = aten::upsample_nearest2d(%0, %3, %4, %4)
            return (%5))IR";

    auto g = std::make_shared<torch::jit::Graph>();

    torch::jit::parseIR(graph, &*g);

    // Input Tensor needs to be 4D for TensorRT upsample_nearest2d
    auto in = at::randint(1, 10, {10, 2, 2, 2}, {at::kCUDA});

    auto jit_in = at::clone(in);
    auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
    auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

    auto trt_in = at::clone(in);
    params = trtorch::core::conversion::get_named_params(g->inputs(), {});
    auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

    auto trt = trt_results[0].reshape(jit_results[0].sizes());

    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenUpsampleNearest3dConvertsCorrectly1dOutputSize) {
    const auto graph = R"IR(
        graph(%0 : Tensor):
            %2 : int = prim::Constant[value=10]()
            %3 : int[] = prim::ListConstruct(%2, %2, %2)
            %4 : None = prim::Constant()
            %5 : Tensor = aten::upsample_nearest3d(%0, %3, %4, %4, %4)
            return (%5))IR";

    auto g = std::make_shared<torch::jit::Graph>();

    torch::jit::parseIR(graph, &*g);

    // Input Tensor needs to be 5D for TensorRT upsample_nearest3d
    auto in = at::randint(1, 10, {10, 2, 2, 2, 2}, {at::kCUDA});

    auto jit_in = at::clone(in);
    auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
    auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

    auto trt_in = at::clone(in);
    params = trtorch::core::conversion::get_named_params(g->inputs(), {});
    auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

    auto trt = trt_results[0].reshape(jit_results[0].sizes());

    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenUpsampleNearest3dConvertsCorrectly3dOutputSize) {
    const auto graph = R"IR(
        graph(%0 : Tensor):
            %2 : int = prim::Constant[value=10]()
            %3 : int[] = prim::ListConstruct(%2, %2, %2)
            %4 : None = prim::Constant()
            %5 : Tensor = aten::upsample_nearest3d(%0, %3, %4, %4, %4)
            return (%5))IR";

    auto g = std::make_shared<torch::jit::Graph>();

    torch::jit::parseIR(graph, &*g);

    // Input Tensor needs to be 5D for TensorRT upsample_nearest3d
    auto in = at::randint(1, 10, {10, 2, 2, 2, 2}, {at::kCUDA});

    auto jit_in = at::clone(in);
    auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
    auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

    auto trt_in = at::clone(in);
    params = trtorch::core::conversion::get_named_params(g->inputs(), {});
    auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

    auto trt = trt_results[0].reshape(jit_results[0].sizes());

    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}
