
Commit bb668db

Merge pull request #446 from NVIDIA/yutec/layer_norm_elementwise_util
Add aten::layer_norm support and move add_elementwise to utils
2 parents 52947fe + 3dc9190 commit bb668db

File tree

7 files changed: +331 −59 lines

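For context, a minimal sketch (not part of this PR) of how the targeted op appears on the TorchScript side: tracing `nn.LayerNorm` is expected to record a single `aten::layer_norm` node with the schema the new converter registers below.

```python
import torch

# Illustrative only: inspect the TorchScript graph a converter like this consumes.
ln = torch.nn.LayerNorm([3, 32, 32])
traced = torch.jit.trace(ln, torch.randn(8, 3, 32, 32))
# The printed graph should contain an aten::layer_norm(Tensor, int[], Tensor?,
# Tensor?, float, bool) node, matching the schema registered in layer_norm.cpp.
print(traced.graph)
```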

core/conversion/converters/BUILD

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ cc_library(
         "impl/element_wise.cpp",
         "impl/expand.cpp",
         "impl/interpolate.cpp",
+        "impl/layer_norm.cpp",
         "impl/linear.cpp",
         "impl/lstm_cell.cpp",
         "impl/matrix_multiply.cpp",

core/conversion/converters/converter_util.cpp

Lines changed: 63 additions & 0 deletions
@@ -1,6 +1,7 @@
 #include "core/conversion/converters/converter_util.h"
 #include "core/conversion/converters/converters.h"
 #include "core/util/prelude.h"
+#include "torch/torch.h"

 namespace trtorch {
 namespace core {
@@ -59,6 +60,68 @@ nvinfer1::ITensor* addUnpadding(
 }
 }

+nvinfer1::ILayer* add_elementwise(
+    ConversionCtx* ctx,
+    nvinfer1::ElementWiseOperation op,
+    nvinfer1::ITensor* self,
+    nvinfer1::ITensor* other,
+    const std::string& name) {
+  // ensure self has the larger number of dimensions
+  bool swapSelfOther = false;
+  if (self->getDimensions().nbDims < other->getDimensions().nbDims) {
+    std::swap(self, other);
+    swapSelfOther = true;
+  }
+  auto selfDim = util::toVec(self->getDimensions());
+  auto otherDim = util::toVec(other->getDimensions());
+  if (selfDim.size() != otherDim.size()) {
+    // other has a dynamic shape; expand its dimensions now and resolve its
+    // shape at runtime
+    if (otherDim.end() != std::find(otherDim.begin(), otherDim.end(), -1)) {
+      auto thOtherStaticShapeMask = torch::ones(selfDim.size(), torch::kInt32);
+      auto thOtherDynamicShapeMask = torch::zeros(selfDim.size(), torch::kInt32);
+      for (size_t start = selfDim.size() - otherDim.size(), idx = 0; idx < otherDim.size(); ++idx) {
+        if (-1 != otherDim[idx]) {
+          thOtherStaticShapeMask[start + idx] = otherDim[idx];
+        } else {
+          thOtherStaticShapeMask[start + idx] = 0;
+          thOtherDynamicShapeMask[start + idx] = 1;
+        }
+      }
+      auto otherStaticShapeMask = tensor_to_const(ctx, thOtherStaticShapeMask);
+      auto otherDynamicShapeMask = tensor_to_const(ctx, thOtherDynamicShapeMask);
+      auto selfShape = ctx->net->addShape(*self)->getOutput(0);
+      // each dynamic dimension of other needs to match the corresponding
+      // dimension of self
+      auto otherDynamicShape =
+          ctx->net->addElementWise(*selfShape, *otherDynamicShapeMask, nvinfer1::ElementWiseOperation::kPROD)
+              ->getOutput(0);
+      auto targetOtherShape =
+          ctx->net->addElementWise(*otherDynamicShape, *otherStaticShapeMask, nvinfer1::ElementWiseOperation::kSUM)
+              ->getOutput(0);
+
+      auto otherShuffle = ctx->net->addShuffle(*other);
+      otherShuffle->setName(std::string("Reshape other tensor to have the same nDim as self for " + name).c_str());
+      otherShuffle->setInput(1, *targetOtherShape);
+      other = otherShuffle->getOutput(0);
+    } else {
+      // other has a static shape; expand its dimensions so both tensors have
+      // the same number of dimensions
+      auto otherShuffle = ctx->net->addShuffle(*other);
+      otherShuffle->setReshapeDimensions(util::toDimsPad(otherDim, selfDim.size()));
+      other = otherShuffle->getOutput(0);
+    }
+  }
+  if (swapSelfOther) {
+    // swap back
+    std::swap(self, other);
+    swapSelfOther = false;
+  }
+  auto ele = ctx->net->addElementWise(*self, *other, op);
+  ele->setName(name.c_str());
+  return ele;
+}
+
 } // namespace converters
 } // namespace conversion
 } // namespace core
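The masking trick above is easier to follow in scalar form. Below is a rough Python sketch of the shape arithmetic `add_elementwise` expresses with `IShapeLayer`/`IElementWiseLayer`; the helper name and plain-list shapes are illustrative only, not part of the TRTorch API.

```python
# Illustrative sketch: the mask arithmetic add_elementwise builds in TensorRT,
# written with plain Python lists. self_shape stands in for the runtime shape
# produced by addShape(self); -1 entries in other_shape mark dynamic dimensions.
def padded_other_shape(self_shape, other_shape):
    start = len(self_shape) - len(other_shape)
    static_mask = [1] * len(self_shape)    # padded leading dims become 1
    dynamic_mask = [0] * len(self_shape)
    for idx, d in enumerate(other_shape):
        if d != -1:
            static_mask[start + idx] = d   # keep the declared static size
        else:
            static_mask[start + idx] = 0
            dynamic_mask[start + idx] = 1  # take this size from self at runtime
    # targetOtherShape = selfShape * dynamicMask + staticMask (element-wise)
    return [s * dm + sm for s, dm, sm in zip(self_shape, dynamic_mask, static_mask)]

# other declared as [3, -1, -1], self with runtime shape [8, 3, 32, 32]:
print(padded_other_shape([8, 3, 32, 32], [3, -1, -1]))  # [1, 3, 32, 32]
# a fully static other is simply left-padded with 1s (the toDimsPad branch):
print(padded_other_shape([8, 3, 32, 32], [32, 32]))     # [1, 1, 32, 32]
```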

core/conversion/converters/converter_util.h

Lines changed: 7 additions & 0 deletions
@@ -35,6 +35,13 @@ nvinfer1::ITensor* addUnpadding(
     bool trailing = true,
     bool use_zeros = true);

+nvinfer1::ILayer* add_elementwise(
+    ConversionCtx* ctx,
+    nvinfer1::ElementWiseOperation op,
+    nvinfer1::ITensor* self,
+    nvinfer1::ITensor* other,
+    const std::string& name);
+
 } // namespace converters
 } // namespace conversion
 } // namespace core

core/conversion/converters/impl/element_wise.cpp

Lines changed: 1 addition & 59 deletions
@@ -1,4 +1,5 @@
 #include <torch/torch.h>
+#include "core/conversion/converters/converter_util.h"
 #include "core/conversion/converters/converters.h"
 #include "core/util/prelude.h"

@@ -9,65 +10,6 @@ namespace converters {
 namespace impl {
 namespace {

-nvinfer1::ILayer* add_elementwise(
-    ConversionCtx* ctx,
-    nvinfer1::ElementWiseOperation op,
-    nvinfer1::ITensor* self,
-    nvinfer1::ITensor* other,
-    const std::string& name) {
-  // ensure self to have larger number of dimension
-  bool swapSelfOther = false;
-  if (self->getDimensions().nbDims < other->getDimensions().nbDims) {
-    std::swap(self, other);
-    swapSelfOther = true;
-  }
-  auto selfDim = util::toVec(self->getDimensions());
-  auto otherDim = util::toVec(other->getDimensions());
-  if (selfDim.size() != otherDim.size()) {
-    // other is with dynamic shape, need to expand its dimension now and get its shape at runtime
-    if (otherDim.end() != std::find(otherDim.begin(), otherDim.end(), -1)) {
-      auto thOtherStaticShapeMask = torch::ones(selfDim.size(), torch::kInt32);
-      auto thOtherDynamicShapeMask = torch::zeros(selfDim.size(), torch::kInt32);
-      for (size_t start = selfDim.size() - otherDim.size(), idx = 0; idx < otherDim.size(); ++idx) {
-        if (-1 != otherDim[idx]) {
-          thOtherStaticShapeMask[start + idx] = otherDim[idx];
-        } else {
-          thOtherStaticShapeMask[start + idx] = 0;
-          thOtherDynamicShapeMask[start + idx] = 1;
-        }
-      }
-      auto otherStaticShapeMask = tensor_to_const(ctx, thOtherStaticShapeMask);
-      auto otherDynamicShapeMask = tensor_to_const(ctx, thOtherDynamicShapeMask);
-      auto selfShape = ctx->net->addShape(*self)->getOutput(0);
-      // size of dynamic dimension of other need to the same as that of corresponding dimension of self
-      auto otherDynamicShape =
-          ctx->net->addElementWise(*selfShape, *otherDynamicShapeMask, nvinfer1::ElementWiseOperation::kPROD)
-              ->getOutput(0);
-      auto targetOtherShape =
-          ctx->net->addElementWise(*otherDynamicShape, *otherStaticShapeMask, nvinfer1::ElementWiseOperation::kSUM)
-              ->getOutput(0);
-
-      auto otherShuffle = ctx->net->addShuffle(*other);
-      otherShuffle->setName(std::string("Reshape other tensor to have the same nDim as self for " + name).c_str());
-      otherShuffle->setInput(1, *targetOtherShape);
-      other = otherShuffle->getOutput(0);
-    } else {
-      // other is with static shape, expand dimension to make tow tensor have the same number of dimension
-      auto otherShuffle = ctx->net->addShuffle(*other);
-      otherShuffle->setReshapeDimensions(util::toDimsPad(otherDim, selfDim.size()));
-      other = otherShuffle->getOutput(0);
-    }
-  }
-  if (swapSelfOther) {
-    // swap back
-    std::swap(self, other);
-    swapSelfOther = false;
-  }
-  auto ele = ctx->net->addElementWise(*self, *other, op);
-  ele->setName(name.c_str());
-  return ele;
-}
-
 nvinfer1::ITensor* clamp_util(
     ConversionCtx* ctx,
     const torch::jit::Node* n,
core/conversion/converters/impl/layer_norm.cpp

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
+#include "core/conversion/converters/converter_util.h"
+#include "core/conversion/converters/converters.h"
+#include "core/util/prelude.h"
+#include "torch/torch.h"
+
+namespace trtorch {
+namespace core {
+namespace conversion {
+namespace converters {
+namespace impl {
+namespace {
+
+auto layer_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern({
+    R"SIG(aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? gamma, Tensor? beta,
+                           float eps, bool cudnn_enabled) -> (Tensor))SIG",
+    [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+      auto input = args[0].ITensor(); // assumes non-static input Tensor
+      auto orig_shape = input->getDimensions();
+      auto shape = util::toVec(orig_shape);
+
+      /* LayerNorm normalizes over the last N dimensions.
+         normalized_shape could be (C,H,W), (H,W), or (W). */
+      auto normalized_shape = args[1].unwrapToIntList();
+      auto normalized_shape_vec = util::toVec(util::toDims(normalized_shape));
+
+      // Unwrap eps.
+      auto eps = args[4].unwrapToDouble();
+
+      LOG_DEBUG("cudnn disregarded");
+
+      // Set up axis_mask for E[x].
+      uint32_t axis_mask = 0;
+      for (size_t i = 0; i < normalized_shape_vec.size(); i++) {
+        axis_mask |= 1 << (shape.size() - i - 1);
+      }
+      LOG_DEBUG("Axis Mask for E[x]: " << std::bitset<32>(axis_mask));
+
+      // E[x]
+      auto mean_expected = ctx->net->addReduce(*input, nvinfer1::ReduceOperation::kAVG, axis_mask, true);
+      TRTORCH_CHECK(mean_expected, "Unable to create mean_expected from node: " << *n);
+      mean_expected->setName((util::node_info(n) + "_mean_expected").c_str());
+      auto mean_expected_out = mean_expected->getOutput(0);
+
+      // x - E[x]
+      auto sub = add_elementwise(
+          ctx, nvinfer1::ElementWiseOperation::kSUB, input, mean_expected_out, (util::node_info(n) + "_sub").c_str());
+      TRTORCH_CHECK(sub, "Unable to create Sub layer from node: " << *n);
+      sub->setName((util::node_info(n) + "_sub").c_str());
+      auto xsubmean_out = sub->getOutput(0);
+
+      // Variance = mean(pow(xsubmean, 2))
+      float pow_scalar = 2;
+      auto exponent = tensor_to_const(ctx, torch::tensor({pow_scalar}));
+      auto pow = add_elementwise(
+          ctx, nvinfer1::ElementWiseOperation::kPOW, xsubmean_out, exponent, (util::node_info(n) + "_pow").c_str());
+      TRTORCH_CHECK(pow, "Unable to create Pow layer from node: " << *n);
+      pow->setName((util::node_info(n) + "_pow").c_str());
+      auto pow_out = pow->getOutput(0);
+
+      auto mean_var = ctx->net->addReduce(*pow_out, nvinfer1::ReduceOperation::kAVG, axis_mask, true);
+      TRTORCH_CHECK(mean_var, "Unable to create mean_var from node: " << *n);
+      mean_var->setName((util::node_info(n) + "_mean_var").c_str());
+      auto mean_var_out = mean_var->getOutput(0);
+
+      // Variance + eps
+      auto eps_tensor = tensor_to_const(ctx, torch::tensor({eps}));
+      auto add = add_elementwise(
+          ctx, nvinfer1::ElementWiseOperation::kSUM, mean_var_out, eps_tensor, (util::node_info(n) + "_add").c_str());
+      TRTORCH_CHECK(add, "Unable to create Add layer from node: " << *n);
+      add->setName((util::node_info(n) + "_add").c_str());
+      auto add_out = add->getOutput(0);
+
+      // sqrt(var + eps)
+      auto sqrt = ctx->net->addUnary(*add_out, nvinfer1::UnaryOperation::kSQRT);
+      TRTORCH_CHECK(sqrt, "Unable to create unary(sqrt) from node: " << *n);
+      sqrt->setName((util::node_info(n) + "_sqrt").c_str());
+      auto sqrt_out = sqrt->getOutput(0);
+
+      // (x - E[x]) / sqrt(var + eps)
+      auto div = add_elementwise(
+          ctx, nvinfer1::ElementWiseOperation::kDIV, xsubmean_out, sqrt_out, (util::node_info(n) + "_div").c_str());
+      TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n);
+      div->setName((util::node_info(n) + "_div").c_str());
+      auto div_out = div->getOutput(0);
+
+      if (!args[2].IValue()->isTensor() && !args[3].IValue()->isTensor()) {
+        ctx->AssociateValueAndTensor(n->outputs()[0], div_out);
+        return true;
+      }
+
+      // Remove the batch dimension from the input shape for expand_size, which
+      // will be used to create weights for addScaleNd later.
+      auto expand_size = shape;
+      expand_size.erase(expand_size.begin(), expand_size.begin() + 1);
+
+      // Set up gamma_weights and beta_weights from gamma_expand and
+      // beta_expand.
+      auto gamma_weights = Weights(ctx, at::ones(expand_size));
+      auto beta_weights = Weights(ctx, at::zeros(expand_size));
+
+      if (args[2].IValue()->isTensor()) {
+        torch::Tensor gamma;
+        gamma = args[2].unwrapToTensor();
+        auto gamma_expand = gamma.expand(expand_size);
+        gamma_weights = Weights(ctx, gamma_expand);
+      } else {
+        gamma_weights = Weights(ctx, at::ones(expand_size));
+      }
+
+      if (args[3].IValue()->isTensor()) {
+        torch::Tensor beta;
+        beta = args[3].unwrapToTensor();
+        auto beta_expand = beta.expand(expand_size);
+        beta_weights = Weights(ctx, beta_expand);
+      } else {
+        beta_weights = Weights(ctx, at::zeros(expand_size));
+      }
+
+      auto power = Weights(ctx, at::ones(expand_size));
+      auto scale_nd = ctx->net->addScaleNd(
+          *div_out, nvinfer1::ScaleMode::kELEMENTWISE, beta_weights.data, gamma_weights.data, power.data, 1);
+      scale_nd->setName((util::node_info(n) + "_scale_nd").c_str());
+      auto scale_nd_out = scale_nd->getOutput(0);
+
+      ctx->AssociateValueAndTensor(n->outputs()[0], scale_nd_out);
+      return true;
+    }});
+
+} // namespace
+} // namespace impl
+} // namespace converters
+} // namespace conversion
+} // namespace core
+} // namespace trtorch
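For reference, a minimal PyTorch sketch (not part of the PR) of the arithmetic the converter's layer chain reproduces: reduce for E[x], subtract, square, reduce again for the variance, add eps, sqrt, divide, then an elementwise scale/shift for gamma and beta. Tensor names here are illustrative; the comparison against `F.layer_norm` is only a sanity check.

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 3, 32, 32)
normalized_shape = (3, 32, 32)      # the last N dims, as in the converter
gamma = torch.randn(normalized_shape)
beta = torch.randn(normalized_shape)
eps = 1e-5

dims = tuple(range(x.dim() - len(normalized_shape), x.dim()))  # axis_mask equivalent
mean = x.mean(dim=dims, keepdim=True)               # E[x]             (kAVG reduce)
xsubmean = x - mean                                 # x - E[x]         (kSUB)
var = (xsubmean ** 2).mean(dim=dims, keepdim=True)  # mean(pow(., 2))  (kPOW + kAVG)
norm = xsubmean / torch.sqrt(var + eps)             # (kSUM, kSQRT, kDIV)
out = norm * gamma + beta                           # addScaleNd, kELEMENTWISE

ref = F.layer_norm(x, normalized_shape, gamma, beta, eps)
print(torch.allclose(out, ref, atol=1e-5))          # expected: True
```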

tests/core/conversion/converters/BUILD

Lines changed: 5 additions & 0 deletions
@@ -35,6 +35,10 @@ converter_test(
     name = "test_expand",
 )

+converter_test(
+    name = "test_layer_norm",
+)
+
 converter_test(
     name = "test_linear",
 )
@@ -110,6 +114,7 @@ test_suite(
         ":test_element_wise",
         ":test_expand",
         ":test_interpolate",
+        ":test_layer_norm",
         ":test_linear",
         ":test_lstm_cell",
         ":test_matrix_multiply",
