@@ -17,9 +17,40 @@ InterpolatePlugin::InterpolatePlugin(
17
17
std::vector<int64_t > in_shape,
18
18
std::vector<int64_t > out_shape,
19
19
std::vector<int64_t > size,
20
+ std::vector<double > scales,
20
21
std::string mode,
21
- bool align_corners)
22
- : in_shape_(in_shape), out_shape_(out_shape), size_(size), mode_(mode), align_corners_(align_corners) {}
22
+ bool align_corners,
23
+ bool use_scales)
24
+ : in_shape_(in_shape),
25
+ out_shape_ (out_shape),
26
+ size_(size),
27
+ scales_(scales),
28
+ mode_(mode),
29
+ align_corners_(align_corners),
30
+ use_scales_(use_scales) {
31
+ if (use_scales) {
32
+ TRTORCH_ASSERT (mode_ != " adaptive_pool2d" , " use_scales is not valid for adaptive_pool2d" );
33
+ TRTORCH_ASSERT (
34
+ scales_.size () != 0 , " Attempted to use interpolate plugin without providing scales while use_scales=true" );
35
+ at::Tensor input = at::randint (1 , 10 , in_shape, {at::kCUDA });
36
+ at::Tensor output;
37
+
38
+ if (mode_ == " linear" ) {
39
+ output = at::upsample_linear1d (input, c10::nullopt, align_corners_, scales_[0 ]);
40
+ } else if (mode_ == " bilinear" ) {
41
+ output = at::upsample_bilinear2d (input, c10::nullopt, align_corners_, scales_);
42
+ std::cout << output.sizes () << std::endl;
43
+ } else if (mode_ == " trilinear" ) {
44
+ output = at::upsample_trilinear3d (input, c10::nullopt, align_corners_, scales_);
45
+ }
46
+
47
+ out_shape_ = output.sizes ().vec ();
48
+ } else {
49
+ TRTORCH_ASSERT (
50
+ (size_.size () != 0 && out_shape_.size () != 0 ),
51
+ " Attempted to use interpolate plugin without providing output size while use_scales=false" );
52
+ }
53
+ }
23
54
24
55
InterpolatePlugin::InterpolatePlugin (const char * data, size_t length) {
25
56
std::istringstream data_stream (std::string (data, length));
@@ -42,6 +73,11 @@ InterpolatePlugin::InterpolatePlugin(const char* data, size_t length) {
42
73
input_archive.read (" size" , value);
43
74
size_ = value.toIntVector ();
44
75
}
76
+ {
77
+ torch::IValue value;
78
+ input_archive.read (" scales" , value);
79
+ scales_ = value.toDoubleVector ();
80
+ }
45
81
{
46
82
torch::IValue value;
47
83
input_archive.read (" mode" , value);
@@ -52,6 +88,11 @@ InterpolatePlugin::InterpolatePlugin(const char* data, size_t length) {
52
88
input_archive.read (" align_corners" , value);
53
89
align_corners_ = value.toBool ();
54
90
}
91
+ {
92
+ torch::IValue value;
93
+ input_archive.read (" use_scales" , value);
94
+ use_scales_ = value.toBool ();
95
+ }
55
96
}
56
97
57
98
std::vector<int64_t > InterpolatePlugin::getInputShape () {
@@ -83,7 +124,7 @@ const char* InterpolatePlugin::getPluginNamespace() const {
83
124
}
84
125
85
126
nvinfer1::IPluginV2DynamicExt* InterpolatePlugin::clone () const {
86
- return new InterpolatePlugin (in_shape_, out_shape_, size_, mode_, align_corners_);
127
+ return new InterpolatePlugin (in_shape_, out_shape_, size_, scales_, mode_, align_corners_, use_scales_ );
87
128
}
88
129
89
130
nvinfer1::DimsExprs InterpolatePlugin::getOutputDimensions (
@@ -93,9 +134,30 @@ nvinfer1::DimsExprs InterpolatePlugin::getOutputDimensions(
93
134
nvinfer1::IExprBuilder& exprBuilder) {
94
135
nvinfer1::DimsExprs output (inputs[0 ]);
95
136
137
+ // TODO: This should enable the case of using this plugin with dynamic shape, scale factor and align corners == true
138
+ // to cover the different implementations between PyTorch and TRT. However TRT currently does not support doubles for
139
+ // ExprBuilder constants. Once that is possible enable this code and remove the code in the constructor if
140
+ // (use_scales_) {
141
+ // auto input_dimsexprs = inputs[0];
142
+ // output.d[0] = exprBuilder.operation(DimensionOperation::kMAX, *input_dimsexprs.d[0], *exprBuilder.constant(0));
143
+ // if (mode_ == "linear") {
144
+ // output.d[1] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[1],
145
+ // *exprBuilder.constant(scales_[1]));
146
+ // } else if (mode_ == "bilinear") {
147
+ // output.d[1] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[1],
148
+ // *exprBuilder.constant(scales_[1])); output.d[2] = exprBuilder.operation(DimensionOperation::kPROD,
149
+ // *input_dimsexprs.d[2], *exprBuilder.constant(scales_[2]));
150
+ // } else if (mode_ == "trilinear") {
151
+ // output.d[1] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[1],
152
+ // *exprBuilder.constant(scales_[1])); output.d[2] = exprBuilder.operation(DimensionOperation::kPROD,
153
+ // *input_dimsexprs.d[2], *exprBuilder.constant(scales_[2])); output.d[3] =
154
+ // exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[3], *exprBuilder.constant(scales_[3]));
155
+ // }
156
+ // } else {
96
157
for (unsigned int i = 0 ; i < out_shape_.size (); i++) {
97
158
output.d [i] = exprBuilder.constant (out_shape_[i]);
98
159
}
160
+ // }
99
161
100
162
return output;
101
163
}
@@ -131,8 +193,10 @@ std::string InterpolatePlugin::serializeToString() const {
131
193
output_archive.write (" in_shape" , torch::IValue (in_shape_));
132
194
output_archive.write (" out_shape" , torch::IValue (out_shape_));
133
195
output_archive.write (" size" , torch::IValue (size_));
196
+ output_archive.write (" scales" , torch::IValue (scales_));
134
197
output_archive.write (" mode" , torch::IValue (mode_));
135
198
output_archive.write (" align_corners" , torch::IValue (align_corners_));
199
+ output_archive.write (" use_scales" , torch::IValue (use_scales_));
136
200
137
201
std::ostringstream data_str;
138
202
output_archive.save_to (data_str);
@@ -201,14 +265,24 @@ int InterpolatePlugin::enqueue(
201
265
202
266
cudaStreamWaitEvent (torch_stream.stream (), event, 0 );
203
267
204
- if (mode_ == " linear" ) {
205
- at::upsample_linear1d_out (output, input, {size_[0 ]}, align_corners_);
206
- } else if (mode_ == " bilinear" ) {
207
- at::upsample_bilinear2d_out (output, input, {size_[0 ], size_[1 ]}, align_corners_);
208
- } else if (mode_ == " trilinear" ) {
209
- at::upsample_trilinear3d_out (output, input, {size_[0 ], size_[1 ], size_[2 ]}, align_corners_);
210
- } else if (mode_ == " adaptive_pool2d" ) {
211
- at::adaptive_avg_pool2d_out (output, input, {size_[0 ], size_[1 ]});
268
+ if (use_scales_) {
269
+ if (mode_ == " linear" ) {
270
+ at::upsample_linear1d_out (output, input, {}, align_corners_, scales_[0 ]);
271
+ } else if (mode_ == " bilinear" ) {
272
+ at::upsample_bilinear2d_out (output, input, {}, align_corners_, scales_[0 ], scales_[1 ]);
273
+ } else if (mode_ == " trilinear" ) {
274
+ at::upsample_trilinear3d_out (output, input, {}, align_corners_, scales_[0 ], scales_[1 ], scales_[2 ]);
275
+ }
276
+ } else {
277
+ if (mode_ == " linear" ) {
278
+ at::upsample_linear1d_out (output, input, {size_[0 ]}, align_corners_);
279
+ } else if (mode_ == " bilinear" ) {
280
+ at::upsample_bilinear2d_out (output, input, {size_[0 ], size_[1 ]}, align_corners_);
281
+ } else if (mode_ == " trilinear" ) {
282
+ at::upsample_trilinear3d_out (output, input, {size_[0 ], size_[1 ], size_[2 ]}, align_corners_);
283
+ } else if (mode_ == " adaptive_pool2d" ) {
284
+ at::adaptive_avg_pool2d_out (output, input, {size_[0 ], size_[1 ]});
285
+ }
212
286
}
213
287
214
288
cudaEvent_t torch_event;
@@ -235,10 +309,25 @@ int InterpolatePlugin::enqueue(
235
309
cudaStreamSynchronize (stream);
236
310
237
311
at::Tensor input = at::from_blob ((void *)input_blob, util::toVec (inputDesc->dims ), tensor_options_);
238
-
239
312
at::Tensor output;
240
- if (mode_ == " adaptive_pool2d" ) {
241
- output = at::adaptive_avg_pool2d (input, {size_[0 ], size_[1 ]});
313
+ if (use_scales_) {
314
+ if (mode_ == " linear" ) {
315
+ output = at::upsample_linear1d (input, c10::nullopt, align_corners_, {scales_[0 ]});
316
+ } else if (mode_ == " bilinear" ) {
317
+ output = at::upsample_bilinear2d (input, c10::nullopt, align_corners_, scales_);
318
+ } else if (mode_ == " trilinear" ) {
319
+ output = at::upsample_trilinear3d (input, c10::nullopt, align_corners_, scales_);
320
+ }
321
+ } else {
322
+ if (mode_ == " linear" ) {
323
+ output = at::upsample_linear1d (input, {size_[0 ]}, align_corners_);
324
+ } else if (mode_ == " bilinear" ) {
325
+ output = at::upsample_bilinear2d (input, {size_[0 ], size_[1 ]}, align_corners_);
326
+ } else if (mode_ == " trilinear" ) {
327
+ output = at::upsample_trilinear3d (input, {size_[0 ], size_[1 ], size_[2 ]}, align_corners_);
328
+ } else if (mode_ == " adaptive_pool2d" ) {
329
+ output = at::adaptive_avg_pool2d (input, {size_[0 ], size_[1 ]});
330
+ }
242
331
}
243
332
244
333
cudaMemcpyAsync (
@@ -277,10 +366,12 @@ InterpolatePlugin* InterpolatePluginCreator::createPlugin(
277
366
std::vector<int64_t > in_shape,
278
367
std::vector<int64_t > out_shape,
279
368
std::vector<int64_t > size,
369
+ std::vector<double > scales,
280
370
std::string mode,
281
- bool align_corners) {
371
+ bool align_corners,
372
+ bool use_scales) {
282
373
name_ = name;
283
- return new InterpolatePlugin (in_shape, out_shape, size, mode, align_corners);
374
+ return new InterpolatePlugin (in_shape, out_shape, size, scales, mode, align_corners, use_scales );
284
375
}
285
376
286
377
nvinfer1::IPluginV2* InterpolatePluginCreator::deserializePlugin (
0 commit comments