Skip to content

Commit 8fb390d

Browse files
authored
Merge pull request #295 from NVIDIA/dynamic_interpolation
Extending @uni19's work on support dynamic shape input and scale_factor in interpolate layer
2 parents 08b2455 + 1781f25 commit 8fb390d

File tree

6 files changed

+1072
-352
lines changed

6 files changed

+1072
-352
lines changed

core/conversion/converters/impl/interpolate.cpp

Lines changed: 575 additions & 66 deletions
Large diffs are not rendered by default.

core/conversion/converters/impl/plugins/interpolate_plugin.cpp

Lines changed: 107 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,40 @@ InterpolatePlugin::InterpolatePlugin(
1717
std::vector<int64_t> in_shape,
1818
std::vector<int64_t> out_shape,
1919
std::vector<int64_t> size,
20+
std::vector<double> scales,
2021
std::string mode,
21-
bool align_corners)
22-
: in_shape_(in_shape), out_shape_(out_shape), size_(size), mode_(mode), align_corners_(align_corners) {}
22+
bool align_corners,
23+
bool use_scales)
24+
: in_shape_(in_shape),
25+
out_shape_(out_shape),
26+
size_(size),
27+
scales_(scales),
28+
mode_(mode),
29+
align_corners_(align_corners),
30+
use_scales_(use_scales) {
31+
if (use_scales) {
32+
TRTORCH_ASSERT(mode_ != "adaptive_pool2d", "use_scales is not valid for adaptive_pool2d");
33+
TRTORCH_ASSERT(
34+
scales_.size() != 0, "Attempted to use interpolate plugin without providing scales while use_scales=true");
35+
at::Tensor input = at::randint(1, 10, in_shape, {at::kCUDA});
36+
at::Tensor output;
37+
38+
if (mode_ == "linear") {
39+
output = at::upsample_linear1d(input, c10::nullopt, align_corners_, scales_[0]);
40+
} else if (mode_ == "bilinear") {
41+
output = at::upsample_bilinear2d(input, c10::nullopt, align_corners_, scales_);
42+
std::cout << output.sizes() << std::endl;
43+
} else if (mode_ == "trilinear") {
44+
output = at::upsample_trilinear3d(input, c10::nullopt, align_corners_, scales_);
45+
}
46+
47+
out_shape_ = output.sizes().vec();
48+
} else {
49+
TRTORCH_ASSERT(
50+
(size_.size() != 0 && out_shape_.size() != 0),
51+
"Attempted to use interpolate plugin without providing output size while use_scales=false");
52+
}
53+
}
2354

2455
InterpolatePlugin::InterpolatePlugin(const char* data, size_t length) {
2556
std::istringstream data_stream(std::string(data, length));
@@ -42,6 +73,11 @@ InterpolatePlugin::InterpolatePlugin(const char* data, size_t length) {
4273
input_archive.read("size", value);
4374
size_ = value.toIntVector();
4475
}
76+
{
77+
torch::IValue value;
78+
input_archive.read("scales", value);
79+
scales_ = value.toDoubleVector();
80+
}
4581
{
4682
torch::IValue value;
4783
input_archive.read("mode", value);
@@ -52,6 +88,11 @@ InterpolatePlugin::InterpolatePlugin(const char* data, size_t length) {
5288
input_archive.read("align_corners", value);
5389
align_corners_ = value.toBool();
5490
}
91+
{
92+
torch::IValue value;
93+
input_archive.read("use_scales", value);
94+
use_scales_ = value.toBool();
95+
}
5596
}
5697

5798
std::vector<int64_t> InterpolatePlugin::getInputShape() {
@@ -83,7 +124,7 @@ const char* InterpolatePlugin::getPluginNamespace() const {
83124
}
84125

85126
nvinfer1::IPluginV2DynamicExt* InterpolatePlugin::clone() const {
86-
return new InterpolatePlugin(in_shape_, out_shape_, size_, mode_, align_corners_);
127+
return new InterpolatePlugin(in_shape_, out_shape_, size_, scales_, mode_, align_corners_, use_scales_);
87128
}
88129

89130
nvinfer1::DimsExprs InterpolatePlugin::getOutputDimensions(
@@ -93,9 +134,30 @@ nvinfer1::DimsExprs InterpolatePlugin::getOutputDimensions(
93134
nvinfer1::IExprBuilder& exprBuilder) {
94135
nvinfer1::DimsExprs output(inputs[0]);
95136

137+
// TODO: This should enable the case of using this plugin with dynamic shape, scale factor and align corners == true
138+
// to cover the different implementations between PyTorch and TRT. However TRT currently does not support doubles for
139+
// ExprBuilder constants. Once that is possible enable this code and remove the code in the constructor if
140+
// (use_scales_) {
141+
// auto input_dimsexprs = inputs[0];
142+
// output.d[0] = exprBuilder.operation(DimensionOperation::kMAX, *input_dimsexprs.d[0], *exprBuilder.constant(0));
143+
// if (mode_ == "linear") {
144+
// output.d[1] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[1],
145+
// *exprBuilder.constant(scales_[1]));
146+
// } else if (mode_ == "bilinear") {
147+
// output.d[1] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[1],
148+
// *exprBuilder.constant(scales_[1])); output.d[2] = exprBuilder.operation(DimensionOperation::kPROD,
149+
// *input_dimsexprs.d[2], *exprBuilder.constant(scales_[2]));
150+
// } else if (mode_ == "trilinear") {
151+
// output.d[1] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[1],
152+
// *exprBuilder.constant(scales_[1])); output.d[2] = exprBuilder.operation(DimensionOperation::kPROD,
153+
// *input_dimsexprs.d[2], *exprBuilder.constant(scales_[2])); output.d[3] =
154+
// exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[3], *exprBuilder.constant(scales_[3]));
155+
// }
156+
// } else {
96157
for (unsigned int i = 0; i < out_shape_.size(); i++) {
97158
output.d[i] = exprBuilder.constant(out_shape_[i]);
98159
}
160+
//}
99161

100162
return output;
101163
}
@@ -131,8 +193,10 @@ std::string InterpolatePlugin::serializeToString() const {
131193
output_archive.write("in_shape", torch::IValue(in_shape_));
132194
output_archive.write("out_shape", torch::IValue(out_shape_));
133195
output_archive.write("size", torch::IValue(size_));
196+
output_archive.write("scales", torch::IValue(scales_));
134197
output_archive.write("mode", torch::IValue(mode_));
135198
output_archive.write("align_corners", torch::IValue(align_corners_));
199+
output_archive.write("use_scales", torch::IValue(use_scales_));
136200

137201
std::ostringstream data_str;
138202
output_archive.save_to(data_str);
@@ -201,14 +265,24 @@ int InterpolatePlugin::enqueue(
201265

202266
cudaStreamWaitEvent(torch_stream.stream(), event, 0);
203267

204-
if (mode_ == "linear") {
205-
at::upsample_linear1d_out(output, input, {size_[0]}, align_corners_);
206-
} else if (mode_ == "bilinear") {
207-
at::upsample_bilinear2d_out(output, input, {size_[0], size_[1]}, align_corners_);
208-
} else if (mode_ == "trilinear") {
209-
at::upsample_trilinear3d_out(output, input, {size_[0], size_[1], size_[2]}, align_corners_);
210-
} else if (mode_ == "adaptive_pool2d") {
211-
at::adaptive_avg_pool2d_out(output, input, {size_[0], size_[1]});
268+
if (use_scales_) {
269+
if (mode_ == "linear") {
270+
at::upsample_linear1d_out(output, input, {}, align_corners_, scales_[0]);
271+
} else if (mode_ == "bilinear") {
272+
at::upsample_bilinear2d_out(output, input, {}, align_corners_, scales_[0], scales_[1]);
273+
} else if (mode_ == "trilinear") {
274+
at::upsample_trilinear3d_out(output, input, {}, align_corners_, scales_[0], scales_[1], scales_[2]);
275+
}
276+
} else {
277+
if (mode_ == "linear") {
278+
at::upsample_linear1d_out(output, input, {size_[0]}, align_corners_);
279+
} else if (mode_ == "bilinear") {
280+
at::upsample_bilinear2d_out(output, input, {size_[0], size_[1]}, align_corners_);
281+
} else if (mode_ == "trilinear") {
282+
at::upsample_trilinear3d_out(output, input, {size_[0], size_[1], size_[2]}, align_corners_);
283+
} else if (mode_ == "adaptive_pool2d") {
284+
at::adaptive_avg_pool2d_out(output, input, {size_[0], size_[1]});
285+
}
212286
}
213287

214288
cudaEvent_t torch_event;
@@ -235,10 +309,25 @@ int InterpolatePlugin::enqueue(
235309
cudaStreamSynchronize(stream);
236310

237311
at::Tensor input = at::from_blob((void*)input_blob, util::toVec(inputDesc->dims), tensor_options_);
238-
239312
at::Tensor output;
240-
if (mode_ == "adaptive_pool2d") {
241-
output = at::adaptive_avg_pool2d(input, {size_[0], size_[1]});
313+
if (use_scales_) {
314+
if (mode_ == "linear") {
315+
output = at::upsample_linear1d(input, c10::nullopt, align_corners_, {scales_[0]});
316+
} else if (mode_ == "bilinear") {
317+
output = at::upsample_bilinear2d(input, c10::nullopt, align_corners_, scales_);
318+
} else if (mode_ == "trilinear") {
319+
output = at::upsample_trilinear3d(input, c10::nullopt, align_corners_, scales_);
320+
}
321+
} else {
322+
if (mode_ == "linear") {
323+
output = at::upsample_linear1d(input, {size_[0]}, align_corners_);
324+
} else if (mode_ == "bilinear") {
325+
output = at::upsample_bilinear2d(input, {size_[0], size_[1]}, align_corners_);
326+
} else if (mode_ == "trilinear") {
327+
output = at::upsample_trilinear3d(input, {size_[0], size_[1], size_[2]}, align_corners_);
328+
} else if (mode_ == "adaptive_pool2d") {
329+
output = at::adaptive_avg_pool2d(input, {size_[0], size_[1]});
330+
}
242331
}
243332

244333
cudaMemcpyAsync(
@@ -277,10 +366,12 @@ InterpolatePlugin* InterpolatePluginCreator::createPlugin(
277366
std::vector<int64_t> in_shape,
278367
std::vector<int64_t> out_shape,
279368
std::vector<int64_t> size,
369+
std::vector<double> scales,
280370
std::string mode,
281-
bool align_corners) {
371+
bool align_corners,
372+
bool use_scales) {
282373
name_ = name;
283-
return new InterpolatePlugin(in_shape, out_shape, size, mode, align_corners);
374+
return new InterpolatePlugin(in_shape, out_shape, size, scales, mode, align_corners, use_scales);
284375
}
285376

286377
nvinfer1::IPluginV2* InterpolatePluginCreator::deserializePlugin(

core/conversion/converters/impl/plugins/interpolate_plugin.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {
3131
std::vector<int64_t> in_shape_;
3232
std::vector<int64_t> out_shape_;
3333
std::vector<int64_t> size_;
34+
std::vector<double> scales_;
3435
std::string mode_;
3536
bool align_corners_;
37+
bool use_scales_;
3638

3739
protected:
3840
// To prevent compiler warnings
@@ -49,8 +51,10 @@ class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {
4951
std::vector<int64_t> in_shape,
5052
std::vector<int64_t> out_shape,
5153
std::vector<int64_t> size,
54+
std::vector<double> scales,
5255
std::string mode,
53-
bool align_corners);
56+
bool align_corners,
57+
bool use_scales);
5458

5559
InterpolatePlugin(const char* data, size_t length);
5660

@@ -140,8 +144,10 @@ class InterpolatePluginCreator : public nvinfer1::IPluginCreator {
140144
std::vector<int64_t> in_shape,
141145
std::vector<int64_t> out_shape,
142146
std::vector<int64_t> size,
147+
std::vector<double> scales,
143148
std::string mode,
144-
bool align_corners);
149+
bool align_corners,
150+
bool use_scales);
145151

146152
nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
147153

core/conversion/converters/impl/pooling.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,14 @@ auto pooling_registrations TRTORCH_UNUSED =
317317

318318
auto creator = new plugins::InterpolatePluginCreator();
319319
auto plugin = creator->createPlugin(
320-
"adaptive_pool2d", in_shape, out_shape, out_size, std::string("adaptive_pool2d"), false);
320+
"adaptive_pool2d",
321+
in_shape,
322+
out_shape,
323+
out_size,
324+
{},
325+
std::string("adaptive_pool2d"),
326+
false,
327+
false);
321328

322329
auto pooling_layer =
323330
ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);

0 commit comments

Comments
 (0)