|
| 1 | +#include "core/conversion/converters/converter_util.h" |
| 2 | +#include "core/conversion/converters/converters.h" |
| 3 | +#include "core/util/prelude.h" |
| 4 | +#include "torch/torch.h" |
| 5 | + |
| 6 | +namespace trtorch { |
| 7 | +namespace core { |
| 8 | +namespace conversion { |
| 9 | +namespace converters { |
| 10 | +namespace impl { |
| 11 | +namespace { |
| 12 | + |
| 13 | +auto layer_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern({ |
| 14 | + R"SIG(aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? gamma, Tensor? beta, |
| 15 | + float eps, bool cudnn_enabled) -> (Tensor))SIG", |
| 16 | + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { |
| 17 | + auto input = args[0].ITensor(); // assumes non-static input Tensor |
| 18 | + auto orig_shape = input->getDimensions(); |
| 19 | + auto shape = util::toVec(orig_shape); |
| 20 | + |
| 21 | + /* Layer_Norm normalizes over last N dimensions. |
| 22 | + normalizaed_shape could be (C,H,W), (H,W), or (W). */ |
| 23 | + auto normalized_shape = args[1].unwrapToIntList(); |
| 24 | + auto normalized_shape_vec = util::toVec(util::toDims(normalized_shape)); |
| 25 | + |
| 26 | + // Unwrap eps. |
| 27 | + auto eps = args[4].unwrapToDouble(); |
| 28 | + |
| 29 | + LOG_DEBUG("cudnn disregarded"); |
| 30 | + |
| 31 | + // Set up axis_ask for E[x]. |
| 32 | + uint32_t axis_mask = 0; |
| 33 | + for (size_t i = 0; i < normalized_shape_vec.size(); i++) { |
| 34 | + axis_mask |= 1 << (shape.size() - i - 1); |
| 35 | + } |
| 36 | + LOG_DEBUG("Axis Mask for E[x]" << std::bitset<32>(axis_mask)); |
| 37 | + |
| 38 | + // E[x] |
| 39 | + auto mean_expected = ctx->net->addReduce(*input, nvinfer1::ReduceOperation::kAVG, axis_mask, true); |
| 40 | + TRTORCH_CHECK(mean_expected, "Unable to create mean_expected from node: " << *n); |
| 41 | + mean_expected->setName((util::node_info(n) + "_mean_expected").c_str()); |
| 42 | + auto mean_expected_out = mean_expected->getOutput(0); |
| 43 | + |
| 44 | + // X-E[x] |
| 45 | + auto sub = add_elementwise( |
| 46 | + ctx, nvinfer1::ElementWiseOperation::kSUB, input, mean_expected_out, (util::node_info(n) + "_sub").c_str()); |
| 47 | + TRTORCH_CHECK(sub, "Unable to create Sub layer from node: " << *n); |
| 48 | + sub->setName((util::node_info(n) + "_sub").c_str()); |
| 49 | + auto xsubmean_out = sub->getOutput(0); |
| 50 | + |
| 51 | + // Variance = mean(pow(xsubmean,2)) |
| 52 | + float pow_scalar = 2; |
| 53 | + auto exponent = tensor_to_const(ctx, torch::tensor({pow_scalar})); |
| 54 | + auto pow = add_elementwise( |
| 55 | + ctx, nvinfer1::ElementWiseOperation::kPOW, xsubmean_out, exponent, (util::node_info(n) + "_pow").c_str()); |
| 56 | + TRTORCH_CHECK(pow, "Unable to create Pow layer from node: " << *n); |
| 57 | + pow->setName((util::node_info(n) + "_pow").c_str()); |
| 58 | + auto pow_out = pow->getOutput(0); |
| 59 | + |
| 60 | + auto mean_var = ctx->net->addReduce(*pow_out, nvinfer1::ReduceOperation::kAVG, axis_mask, true); |
| 61 | + TRTORCH_CHECK(mean_var, "Unable to create mean_var from node: " << *n); |
| 62 | + mean_var->setName((util::node_info(n) + "_mean_var").c_str()); |
| 63 | + auto mean_var_out = mean_var->getOutput(0); |
| 64 | + |
| 65 | + // Variance + eps |
| 66 | + auto eps_tensor = tensor_to_const(ctx, torch::tensor({eps})); |
| 67 | + auto add = add_elementwise( |
| 68 | + ctx, nvinfer1::ElementWiseOperation::kSUM, mean_var_out, eps_tensor, (util::node_info(n) + "_add").c_str()); |
| 69 | + TRTORCH_CHECK(add, "Unable to create Add layer from node: " << *n); |
| 70 | + add->setName((util::node_info(n) + "_add").c_str()); |
| 71 | + auto add_out = add->getOutput(0); |
| 72 | + |
| 73 | + // SQRT((Var + eps)) |
| 74 | + auto sqrt = ctx->net->addUnary(*add_out, nvinfer1::UnaryOperation::kSQRT); |
| 75 | + TRTORCH_CHECK(sqrt, "Unable to create unary(sqrt) from node: " << *n); |
| 76 | + sqrt->setName((util::node_info(n) + "_sqrt").c_str()); |
| 77 | + auto sqrt_out = sqrt->getOutput(0); |
| 78 | + |
| 79 | + // (x - E[x]) / sqrt((var + eps)) |
| 80 | + auto div = add_elementwise( |
| 81 | + ctx, nvinfer1::ElementWiseOperation::kDIV, xsubmean_out, sqrt_out, (util::node_info(n) + "_div").c_str()); |
| 82 | + TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n); |
| 83 | + div->setName((util::node_info(n) + "_div").c_str()); |
| 84 | + auto div_out = div->getOutput(0); |
| 85 | + |
| 86 | + if (!args[2].IValue()->isTensor() && !args[3].IValue()->isTensor()) { |
| 87 | + ctx->AssociateValueAndTensor(n->outputs()[0], div_out); |
| 88 | + return true; |
| 89 | + } |
| 90 | + |
| 91 | + // Remove batch dimension from input shape for expand_size, which will |
| 92 | + // be used to create weights for addScaleNd later. |
| 93 | + auto expand_size = shape; |
| 94 | + expand_size.erase(expand_size.begin(), expand_size.begin() + 1); |
| 95 | + |
| 96 | + // Set up gamma_weights and beta_weights from gamma_expand and |
| 97 | + // beta_expand. |
| 98 | + auto gamma_weights = Weights(ctx, at::ones(expand_size)); |
| 99 | + auto beta_weights = Weights(ctx, at::zeros(expand_size)); |
| 100 | + |
| 101 | + if (args[2].IValue()->isTensor()) { |
| 102 | + torch::Tensor gamma; |
| 103 | + gamma = args[2].unwrapToTensor(); |
| 104 | + auto gamma_expand = gamma.expand(expand_size); |
| 105 | + gamma_weights = Weights(ctx, gamma_expand); |
| 106 | + } else { |
| 107 | + gamma_weights = Weights(ctx, at::ones(expand_size)); |
| 108 | + } |
| 109 | + |
| 110 | + if (args[3].IValue()->isTensor()) { |
| 111 | + torch::Tensor beta; |
| 112 | + beta = args[3].unwrapToTensor(); |
| 113 | + auto beta_expand = beta.expand(expand_size); |
| 114 | + beta_weights = Weights(ctx, beta_expand); |
| 115 | + } else { |
| 116 | + beta_weights = Weights(ctx, at::zeros(expand_size)); |
| 117 | + } |
| 118 | + |
| 119 | + auto power = Weights(ctx, at::ones(expand_size)); |
| 120 | + auto scale_nd = ctx->net->addScaleNd( |
| 121 | + *div_out, nvinfer1::ScaleMode::kELEMENTWISE, beta_weights.data, gamma_weights.data, power.data, 1); |
| 122 | + scale_nd->setName((util::node_info(n) + "_scale_nd").c_str()); |
| 123 | + auto scale_nd_out = scale_nd->getOutput(0); |
| 124 | + |
| 125 | + ctx->AssociateValueAndTensor(n->outputs()[0], scale_nd_out); |
| 126 | + return true; |
| 127 | + }}); |
| 128 | + |
| 129 | +} // namespace |
| 130 | +} // namespace impl |
| 131 | +} // namespace converters |
| 132 | +} // namespace conversion |
| 133 | +} // namespace core |
| 134 | +} // namespace trtorch |
0 commit comments