|
1 | 1 | #include "core/conversion/converters/converter_util.h"
|
2 | 2 | #include "core/conversion/converters/converters.h"
|
3 | 3 | #include "core/util/prelude.h"
|
4 |
| -#include "plugins/interpolate_plugin.h" |
5 | 4 |
|
6 | 5 | namespace trtorch {
|
7 | 6 | namespace core {
|
@@ -74,21 +73,42 @@ bool AdaptivePoolingConverter(
|
74 | 73 | pool_type == nvinfer1::PoolingType::kAVERAGE,
|
75 | 74 | "Unable to create MAX pooling (interpolation) plugin from node" << *n);
|
76 | 75 |
|
| 76 | + nvinfer1::PluginFieldCollection fc; |
| 77 | + std::vector<nvinfer1::PluginField> f; |
| 78 | + |
77 | 79 | auto out_shape = in_shape;
|
78 | 80 | std::copy_n(out_size.d, out_size.nbDims, out_shape.begin() + (in_shape.size() - out_size.nbDims));
|
79 | 81 |
|
80 |
| - auto creator = new plugins::InterpolatePluginCreator(); |
81 |
| - auto plugin = creator->createPlugin( |
82 |
| - "adaptive_pool2d", |
83 |
| - in_shape, |
84 |
| - out_shape, |
85 |
| - util::toVec(out_size), |
86 |
| - {}, |
87 |
| - std::string("adaptive_pool2d"), |
88 |
| - false, |
89 |
| - false); |
90 |
| - |
91 |
| - new_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin); |
| 82 | + std::vector<int32_t> in_shape_casted(in_shape.begin(), in_shape.end()); |
| 83 | + f.emplace_back( |
| 84 | + nvinfer1::PluginField("in_shape", in_shape_casted.data(), nvinfer1::PluginFieldType::kINT32, in_shape.size())); |
| 85 | + |
| 86 | + std::vector<int32_t> out_shape_casted(out_shape.begin(), out_shape.end()); |
| 87 | + f.emplace_back( |
| 88 | + nvinfer1::PluginField("out_shape", out_shape_casted.data(), nvinfer1::PluginFieldType::kINT32, out_shape.size())); |
| 89 | + |
| 90 | + auto out_size_vec = util::toVec(out_size); |
| 91 | + std::vector<int32_t> out_size_casted(out_size_vec.begin(), out_size_vec.end()); |
| 92 | + f.emplace_back( |
| 93 | + nvinfer1::PluginField("out_size", out_size_casted.data(), nvinfer1::PluginFieldType::kINT32, out_size_vec.size())); |
| 94 | + |
| 95 | + f.emplace_back(nvinfer1::PluginField("scales", nullptr, nvinfer1::PluginFieldType::kFLOAT64, 0)); |
| 96 | + |
| 97 | + std::string mode = "adaptive_pool2d"; |
| 98 | + f.emplace_back(nvinfer1::PluginField("mode", &mode, nvinfer1::PluginFieldType::kCHAR, 1)); |
| 99 | + |
| 100 | + int32_t align_corners_casted = 0; |
| 101 | + f.emplace_back(nvinfer1::PluginField("align_corners", &align_corners_casted, nvinfer1::PluginFieldType::kINT32, 1)); |
| 102 | + |
| 103 | + int32_t use_scales_casted = 0; |
| 104 | + f.emplace_back(nvinfer1::PluginField("use_scales", &use_scales_casted, nvinfer1::PluginFieldType::kINT32, 1)); |
| 105 | + |
| 106 | + fc.nbFields = f.size(); |
| 107 | + fc.fields = f.data(); |
| 108 | + auto creator = getPluginRegistry()->getPluginCreator("Interpolate", "1", "trtorch"); |
| 109 | + auto interpolate_plugin = creator->createPlugin("adaptive_pool2d", &fc); |
| 110 | + |
| 111 | + new_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *interpolate_plugin); |
92 | 112 | TRTORCH_CHECK(new_layer, "Unable to create pooling (interpolation) plugin from node" << *n);
|
93 | 113 |
|
94 | 114 | } else {
|
|
0 commit comments