Skip to content

Commit a294f9e

Browse files
committed
Fix pooling
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 5a53c13 commit a294f9e

File tree

2 files changed

+34
-14
lines changed

2 files changed

+34
-14
lines changed

core/conversion/converters/impl/pooling.cpp

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#include "core/conversion/converters/converter_util.h"
22
#include "core/conversion/converters/converters.h"
33
#include "core/util/prelude.h"
4-
#include "plugins/interpolate_plugin.h"
54

65
namespace trtorch {
76
namespace core {
@@ -74,21 +73,42 @@ bool AdaptivePoolingConverter(
7473
pool_type == nvinfer1::PoolingType::kAVERAGE,
7574
"Unable to create MAX pooling (interpolation) plugin from node" << *n);
7675

76+
nvinfer1::PluginFieldCollection fc;
77+
std::vector<nvinfer1::PluginField> f;
78+
7779
auto out_shape = in_shape;
7880
std::copy_n(out_size.d, out_size.nbDims, out_shape.begin() + (in_shape.size() - out_size.nbDims));
7981

80-
auto creator = new plugins::InterpolatePluginCreator();
81-
auto plugin = creator->createPlugin(
82-
"adaptive_pool2d",
83-
in_shape,
84-
out_shape,
85-
util::toVec(out_size),
86-
{},
87-
std::string("adaptive_pool2d"),
88-
false,
89-
false);
90-
91-
new_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);
82+
std::vector<int32_t> in_shape_casted(in_shape.begin(), in_shape.end());
83+
f.emplace_back(
84+
nvinfer1::PluginField("in_shape", in_shape_casted.data(), nvinfer1::PluginFieldType::kINT32, in_shape.size()));
85+
86+
std::vector<int32_t> out_shape_casted(out_shape.begin(), out_shape.end());
87+
f.emplace_back(
88+
nvinfer1::PluginField("out_shape", out_shape_casted.data(), nvinfer1::PluginFieldType::kINT32, out_shape.size()));
89+
90+
auto out_size_vec = util::toVec(out_size);
91+
std::vector<int32_t> out_size_casted(out_size_vec.begin(), out_size_vec.end());
92+
f.emplace_back(
93+
nvinfer1::PluginField("out_size", out_size_casted.data(), nvinfer1::PluginFieldType::kINT32, out_size_vec.size()));
94+
95+
f.emplace_back(nvinfer1::PluginField("scales", nullptr, nvinfer1::PluginFieldType::kFLOAT64, 0));
96+
97+
std::string mode = "adaptive_pool2d";
98+
f.emplace_back(nvinfer1::PluginField("mode", &mode, nvinfer1::PluginFieldType::kCHAR, 1));
99+
100+
int32_t align_corners_casted = 0;
101+
f.emplace_back(nvinfer1::PluginField("align_corners", &align_corners_casted, nvinfer1::PluginFieldType::kINT32, 1));
102+
103+
int32_t use_scales_casted = 0;
104+
f.emplace_back(nvinfer1::PluginField("use_scales", &use_scales_casted, nvinfer1::PluginFieldType::kINT32, 1));
105+
106+
fc.nbFields = f.size();
107+
fc.fields = f.data();
108+
auto creator = getPluginRegistry()->getPluginCreator("Interpolate", "1", "trtorch");
109+
auto interpolate_plugin = creator->createPlugin("adaptive_pool2d", &fc);
110+
111+
new_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *interpolate_plugin);
92112
TRTORCH_CHECK(new_layer, "Unable to create pooling (interpolation) plugin from node" << *n);
93113

94114
} else {

core/plugins/impl/normalize_plugin.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ nvinfer1::DimsExprs NormalizePlugin::getOutputDimensions(
7373
// TODO: For dim=None, the axes_ passed would have [0, 0, 0] which is obtained through loop counter in TRTorch.
7474
// Resolve this. For dim=None case, change the axes_ inplace to range(0, axes_.size())
7575
bool isAxisNone =
76-
std::all_of(axes_.begin(), axes_.end(), [](int32_t i) { return i == 0; }) && (axes_.size() == inputs[0].nbDims);
76+
std::all_of(axes_.begin(), axes_.end(), [](int32_t i) { return i == 0; }) && ((int32_t) axes_.size() == inputs[0].nbDims);
7777
if (isAxisNone) {
7878
std::iota(axes_.data(), axes_.data() + axes_.size(), 0);
7979
}

0 commit comments

Comments
 (0)