Skip to content

Commit 5b8b819

Browse files
authored
Merge pull request #425 from NVIDIA/plugins
feat(core/plugins): Plugins redesign
2 parents f053d32 + eb6ed7b commit 5b8b819

28 files changed

+1113
-245
lines changed

BUILD

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ pkg_tar(
1515
"//core/conversion:include",
1616
"//core/conversion/conversionctx:include",
1717
"//core/conversion/converters:include",
18-
"//core/conversion/converters/impl/plugins:include",
19-
"//core/conversion/evaluators:include",
20-
"//core/conversion/tensorcontainer:include",
2118
"//core/conversion/var:include",
19+
"//core/conversion/tensorcontainer:include",
20+
"//core/conversion/evaluators:include",
21+
"//core/plugins:include",
2222
"//core/lowering:include",
2323
"//core/lowering/passes:include",
2424
"//core/runtime:include",
@@ -42,6 +42,7 @@ pkg_tar(
4242
"//conditions:default": [
4343
"//cpp/api/lib:libtrtorch.so",
4444
"//cpp/api/lib:libtrtorchrt.so",
45+
"//cpp/api/lib:libtrtorch_plugins.so",
4546
],
4647
}),
4748
mode = "0755",

core/conversion/conversionctx/ConversionCtx.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1+
#include "core/conversion/conversionctx/ConversionCtx.h"
12
#include <iostream>
23
#include <sstream>
34
#include <utility>
45

5-
#include "core/conversion/conversionctx/ConversionCtx.h"
6-
76
namespace trtorch {
87
namespace core {
98
namespace conversion {

core/conversion/converters/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ cc_library(
4343
"impl/linear.cpp",
4444
"impl/lstm_cell.cpp",
4545
"impl/matrix_multiply.cpp",
46+
"impl/normalize.cpp",
4647
"impl/pooling.cpp",
4748
"impl/reduce.cpp",
4849
"impl/replication_pad.cpp",
@@ -65,7 +66,7 @@ cc_library(
6566
"//core/conversion/var",
6667
"//core/conversion/tensorcontainer",
6768
"//core/conversion/conversionctx",
68-
"//core/conversion/converters/impl/plugins",
69+
"//core/plugins:trtorch_plugins",
6970
] + select({
7071
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
7172
"//conditions:default": ["@libtorch//:libtorch"],

core/conversion/converters/impl/activation.cpp

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,6 @@ auto acthardtanh TRTORCH_UNUSED =
146146

147147
auto new_layer = ctx->net->addActivation(*self, nvinfer1::ActivationType::kLEAKY_RELU);
148148
new_layer->setAlpha(negative_slopeScalar);
149-
150149
new_layer->setName(util::node_info(n).c_str());
151150
auto out_tensor = new_layer->getOutput(0);
152151
out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
@@ -167,6 +166,35 @@ auto acthardtanh TRTORCH_UNUSED =
167166
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
168167
LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
169168
return true;
169+
}})
170+
.pattern({"aten::gelu(Tensor self) -> (Tensor)",
          // Converts aten::gelu by instantiating the "CustomGeluPluginDynamic" (version "1") creator
          // looked up in the TensorRT plugin registry and adding it to the network as a plugin layer.
          // Returns true on success; every failure path goes through TRTORCH_CHECK.
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            auto in = args[0].ITensorOrFreeze(ctx);
            nvinfer1::DataType type = in->getType();
            TRTORCH_CHECK(
                type == nvinfer1::DataType::kFLOAT || type == nvinfer1::DataType::kHALF,
                "gelu only supports kFLOAT and kHALF");

            // Integer encoding the DataType (0: FP32, 1: FP16).
            // NOTE(review): this is derived from the requested op precision, not from the input
            // tensor type checked above - confirm that is the intended source.
            int type_id = ctx->settings.op_precision == nvinfer1::DataType::kFLOAT ? 0 : 1;

            // The field vector and type_id must stay alive until createPlugin() returns, because the
            // collection only stores pointers into them.
            std::vector<nvinfer1::PluginField> f;
            f.emplace_back(nvinfer1::PluginField("type_id", &type_id, nvinfer1::PluginFieldType::kINT32, 1));
            nvinfer1::PluginFieldCollection fc;
            fc.nbFields = static_cast<int32_t>(f.size());
            fc.fields = f.data();

            // The registry lookup can fail (e.g. the plugin library was not loaded); check before
            // dereferencing the creator.
            auto creator = getPluginRegistry()->getPluginCreator("CustomGeluPluginDynamic", "1", "");
            TRTORCH_CHECK(creator, "Unable to find CustomGeluPluginDynamic creator in TensorRT plugin registry");
            auto gelu_plugin = creator->createPlugin("gelu", &fc);
            TRTORCH_CHECK(gelu_plugin, "Unable to create gelu plugin from TensorRT plugin registry" << *n);

            auto new_layer =
                ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *gelu_plugin);
            TRTORCH_CHECK(new_layer, "Unable to create gelu layer from node: " << *n);
            // Name the layer after the node (as the other converters in this file do) so that several
            // gelu ops in one graph do not end up with the same layer name.
            new_layer->setName(util::node_info(n).c_str());

            auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
            LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
            return true;
          }});
171199

172200
} // namespace

core/conversion/converters/impl/instance_norm.cpp

Lines changed: 0 additions & 118 deletions
This file was deleted.

core/conversion/converters/impl/interpolate.cpp

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#include "NvInferRuntimeCommon.h"
33
#include "core/conversion/converters/converters.h"
44
#include "core/util/prelude.h"
5-
#include "plugins/interpolate_plugin.h"
65
#include "torch/torch.h"
76

87
namespace trtorch {
@@ -28,11 +27,36 @@ void create_plugin(
2827
bool align_corners,
2928
bool use_scales = false) {
3029
LOG_WARNING("Interpolation layer will be run through ATen, not TensorRT. Performance may be lower than expected");
30+
nvinfer1::PluginFieldCollection fc;
31+
std::vector<nvinfer1::PluginField> f;
3132

32-
auto creator = new plugins::InterpolatePluginCreator();
33-
auto plugin = creator->createPlugin(name, in_shape, out_shape, out_size, scales, mode, align_corners, use_scales);
33+
std::vector<int32_t> in_shape_casted(in_shape.begin(), in_shape.end());
34+
f.emplace_back(
35+
nvinfer1::PluginField("in_shape", in_shape_casted.data(), nvinfer1::PluginFieldType::kINT32, in_shape.size()));
3436

35-
auto resize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);
37+
std::vector<int32_t> out_shape_casted(out_shape.begin(), out_shape.end());
38+
f.emplace_back(
39+
nvinfer1::PluginField("out_shape", out_shape_casted.data(), nvinfer1::PluginFieldType::kINT32, out_shape.size()));
40+
41+
std::vector<int32_t> out_size_casted(out_size.begin(), out_size.end());
42+
f.emplace_back(
43+
nvinfer1::PluginField("out_size", out_size_casted.data(), nvinfer1::PluginFieldType::kINT32, out_size.size()));
44+
45+
f.emplace_back(nvinfer1::PluginField("scales", scales.data(), nvinfer1::PluginFieldType::kFLOAT64, scales.size()));
46+
f.emplace_back(nvinfer1::PluginField("mode", &mode, nvinfer1::PluginFieldType::kCHAR, 1));
47+
48+
int32_t align_corners_casted = static_cast<int32_t>(align_corners);
49+
f.emplace_back(nvinfer1::PluginField("align_corners", &align_corners_casted, nvinfer1::PluginFieldType::kINT32, 1));
50+
51+
int32_t use_scales_casted = static_cast<int32_t>(use_scales);
52+
f.emplace_back(nvinfer1::PluginField("use_scales", &use_scales_casted, nvinfer1::PluginFieldType::kINT32, 1));
53+
54+
fc.nbFields = f.size();
55+
fc.fields = f.data();
56+
auto creator = getPluginRegistry()->getPluginCreator("Interpolate", "1", "trtorch");
57+
auto interpolate_plugin = creator->createPlugin(name, &fc);
58+
59+
auto resize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *interpolate_plugin);
3660
TRTORCH_CHECK(resize_layer, "Unable to create interpolation plugin from node" << *n);
3761

3862
resize_layer->setName(util::node_info(n).c_str());
@@ -779,4 +803,4 @@ auto interpolate_registrations TRTORCH_UNUSED =
779803
} // namespace converters
780804
} // namespace conversion
781805
} // namespace core
782-
} // namespace trtorch
806+
} // namespace trtorch
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
#include "NvInfer.h"
2+
#include "NvInferRuntimeCommon.h"
3+
#include "core/conversion/converters/converters.h"
4+
#include "core/util/prelude.h"
5+
#include "torch/torch.h"
6+
7+
namespace trtorch {
8+
namespace core {
9+
namespace conversion {
10+
namespace converters {
11+
namespace impl {
12+
namespace {
13+
14+
/*
15+
* Helper functions
16+
*/
17+
/**
 * Adds a "NormalizePlugin" (version "1", namespace "trtorch") layer to the network for node `n`
 * and associates the layer's first output with the node's first output value.
 *
 * @param ctx        conversion context owning the network being built
 * @param n          TorchScript node being converted (used for naming and error messages)
 * @param in         input tensor of the normalization
 * @param order      order of the norm; forwarded to the plugin as a 32-bit integer
 * @param axes       axes to reduce over; negative values are wrapped by the input rank
 * @param keep_dims  whether reduced dimensions are kept; forwarded as a 32-bit integer
 * @param name       name handed to the plugin creator
 *
 * Raises through TRTORCH_THROW_ERROR / TRTORCH_CHECK when an axis is out of range, or when the
 * plugin creator, the plugin, or the layer cannot be obtained.
 */
void create_plugin(
    ConversionCtx* ctx,
    const torch::jit::Node* n,
    nvinfer1::ITensor* in,
    int64_t order,
    std::vector<int32_t> axes,
    bool keep_dims,
    const char* name) {
  LOG_WARNING("Normalize layer will be run through ATen, not TensorRT. Performance may be lower than expected");

  // Wrap negative axes and reject anything that is still outside [0, rank) before the plugin sees it.
  const auto input_rank = in->getDimensions().nbDims;
  for (auto& axis : axes) {
    if (axis < 0) {
      axis += input_rank;
    }
    if (axis < 0 || axis > input_rank - 1) {
      TRTORCH_THROW_ERROR("Axis of normalization layer cannot exceed input rank");
    }
  }

  // PluginField only stores a raw pointer plus a declared element type, so the pointed-to storage
  // must really be int32_t. Passing &order (int64_t) or &keep_dims (bool) as kINT32 would make the
  // plugin read the wrong number of bytes.
  const int32_t order_i32 = static_cast<int32_t>(order);
  const int32_t keep_dims_i32 = static_cast<int32_t>(keep_dims);

  std::vector<nvinfer1::PluginField> f;
  f.emplace_back(nvinfer1::PluginField("order", &order_i32, nvinfer1::PluginFieldType::kINT32, 1));
  f.emplace_back(nvinfer1::PluginField(
      "axes", axes.data(), nvinfer1::PluginFieldType::kINT32, static_cast<int32_t>(axes.size())));
  f.emplace_back(nvinfer1::PluginField("keep_dims", &keep_dims_i32, nvinfer1::PluginFieldType::kINT32, 1));

  nvinfer1::PluginFieldCollection fc;
  fc.nbFields = static_cast<int32_t>(f.size());
  fc.fields = f.data();

  // The registry lookup and the plugin construction can both fail; check before dereferencing.
  auto creator = getPluginRegistry()->getPluginCreator("NormalizePlugin", "1", "trtorch");
  TRTORCH_CHECK(creator, "Unable to find NormalizePlugin creator in TensorRT plugin registry");
  auto plugin = creator->createPlugin(name, &fc);
  TRTORCH_CHECK(plugin, "Unable to create normalization plugin from TensorRT plugin registry" << *n);

  auto normalize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);
  TRTORCH_CHECK(normalize_layer, "Unable to create normalization plugin from node" << *n);

  normalize_layer->setName(util::node_info(n).c_str());

  auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], normalize_layer->getOutput(0));

  LOG_DEBUG("Normalize layer output tensor shape: " << layer_output->getDimensions());
}
55+
56+
// Registers the converter for aten::norm.ScalarOpt_dim. The converter unpacks the norm order, the
// reduction axes and the keepdim flag from the node arguments and delegates to create_plugin().
auto normalize_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
    {"aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> (Tensor)",
     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
       auto in = args[0].ITensor();
       // NOTE(review): the schema declares `p` as optional (Scalar?); this assumes a value is
       // always supplied - confirm how a None order should be handled.
       auto order = args[1].unwrapToScalar().to<int32_t>();

       // The plugin takes 32-bit axes, so narrow the int64 list coming from TorchScript.
       auto axes_values = args[2].unwrapToIntList().vec();
       std::vector<int32_t> axes(axes_values.begin(), axes_values.end());

       // Kept as an integer so the debug output below prints 0/1.
       auto keep_dims = static_cast<int32_t>(args[3].unwrapToBool());

       LOG_DEBUG("Order of normalize_plugin: " << order);
       LOG_DEBUG("Axis: " << axes);
       LOG_DEBUG("keep_dims: " << keep_dims);

       create_plugin(ctx, n, in, order, axes, keep_dims, "NormalizePlugintrtorch");
       return true;
     }

    });
73+
74+
} // namespace
75+
} // namespace impl
76+
} // namespace converters
77+
} // namespace conversion
78+
} // namespace core
79+
} // namespace trtorch

core/conversion/converters/impl/plugins/BUILD

Lines changed: 0 additions & 41 deletions
This file was deleted.

0 commit comments

Comments
 (0)