[ET][Optimized] Fix & cleanup optimized native_layer_norm #812

Closed
188 changes: 91 additions & 97 deletions kernels/optimized/cpu/op_native_layer_norm.cpp
@@ -13,6 +13,7 @@
#include <executorch/kernels/optimized/cpu/moments_utils.h>
#include <executorch/kernels/optimized/vec/functional.h>
#include <executorch/kernels/optimized/vec/vec.h>
+#include <executorch/kernels/portable/cpu/util/normalization_ops_util.h>

namespace torch {
namespace executor {
@@ -25,148 +26,141 @@ namespace {
template <typename CTYPE>
void layer_norm(
const Tensor& input,
-const Tensor& gamma,
-const Tensor& beta,
-const size_t M,
-const size_t N,
+IntArrayRef normalized_shape,
+const optional<Tensor>& weight,
+const optional<Tensor>& bias,
CTYPE eps,
-Tensor& output,
+Tensor& out,
Tensor& mean,
Tensor& rstd) {
using Vec = executorch::vec::Vectorized<CTYPE>;

-const CTYPE* __restrict__ input_data = input.data_ptr<CTYPE>();
-const CTYPE* __restrict__ gamma_data = gamma.data_ptr<CTYPE>();
-const CTYPE* __restrict__ beta_data = beta.data_ptr<CTYPE>();
-CTYPE* __restrict__ output_data = output.data_ptr<CTYPE>();
+const size_t dim = input.dim() - normalized_shape.size();
+const size_t dim_size = input.size(dim);
+
+const size_t M = getLeadingDims(input, dim);
+const size_t N = getTrailingDims(input, dim) * dim_size;
+
+if (M == 0) {
+return;
+}
+
+CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
+CTYPE* mean_data = mean.mutable_data_ptr<CTYPE>();
+CTYPE* rstd_data = rstd.mutable_data_ptr<CTYPE>();
+
+if (N == 0) {
+for (int i = 0; i < M; ++i) {
+mean_data[i] = static_cast<CTYPE>(0);
+rstd_data[i] = static_cast<CTYPE>(NAN);
+}
+return;
+}
+
+const CTYPE* input_data = input.const_data_ptr<CTYPE>();
+const CTYPE* gamma_data;
+if (weight.has_value()) {
+gamma_data = weight.value().const_data_ptr<CTYPE>();
+} else {
+gamma_data = nullptr;
+}
+const CTYPE* beta_data;
+if (bias.has_value()) {
+beta_data = bias.value().const_data_ptr<CTYPE>();
+} else {
+beta_data = nullptr;
+}

const bool gamma_null = gamma_data == nullptr;
const bool beta_null = beta_data == nullptr;

for (size_t i = 0; i < M; ++i) {
const CTYPE* src_ptr = input_data + i * N;
-CTYPE* dst_ptr = output_data + i * N;
+CTYPE* dst_ptr = out_data + i * N;

CTYPE mean_val;
CTYPE rstd_val;
std::tie(mean_val, rstd_val) = RowwiseMoments(src_ptr, N);
rstd_val = CTYPE(1) / std::sqrt(rstd_val + eps);

const CTYPE scale = rstd_val;
-const CTYPE bias = -rstd_val * mean_val;
+const CTYPE offset = -rstd_val * mean_val;

if (gamma_null || beta_null) {
for (size_t j = 0; j < N; ++j) {
const CTYPE gamma_v = gamma_null ? CTYPE(1) : gamma_data[j];
const CTYPE beta_v = beta_null ? CTYPE(0) : beta_data[j];
-dst_ptr[j] = (src_ptr[j] * scale + bias) * gamma_v + beta_v;
+dst_ptr[j] = (src_ptr[j] * scale + offset) * gamma_v + beta_v;
}
} else {
executorch::vec::map3<CTYPE>(
-[scale, bias](Vec x, Vec gamma, Vec beta) {
-return (x * Vec(scale) + Vec(bias)) * gamma + beta;
+[scale, offset](Vec x, Vec gamma, Vec beta) {
+return (x * Vec(scale) + Vec(offset)) * gamma + beta;
},
dst_ptr,
src_ptr,
gamma_data,
beta_data,
N);
}
-}
-
-// Assign NAN to mean and rstd. They are not used in seen examples.
-// Use NAN to make the error more obvious in case they are used.
-mean.data_ptr<CTYPE>()[0] = NAN;
-rstd.data_ptr<CTYPE>()[0] = NAN;
+
+mean_data[i] = mean_val;
+rstd_data[i] = rstd_val;
+}
}

} // namespace

// native_layer_norm.out(Tensor input, int[] normalized_shape, Tensor? weight,
// Tensor? bias, float eps, *, Tensor(a!) out, Tensor(b!) mean_out, Tensor(c!)
// rstd_out) -> (Tensor(a!), Tensor(b!), Tensor(c!))
-//
-// Unlike the ATen implementation of native_layer_norm, mean_out and rstd_out
-// are not filled with any meaningful data. Instead, they are set to NAN to
-// easily detect if they are being used.
std::tuple<Tensor&, Tensor&, Tensor&> opt_native_layer_norm_out(
-RuntimeContext& context,
+RuntimeContext& ctx,
const Tensor& input,
IntArrayRef normalized_shape,
-const exec_aten::optional<Tensor>& gamma,
-const exec_aten::optional<Tensor>& beta,
+const exec_aten::optional<Tensor>& weight,
+const exec_aten::optional<Tensor>& bias,
double eps,
Tensor& out,
Tensor& mean_out,
Tensor& rstd_out) {
-(void)context;
-
-ET_CHECK_MSG(
-normalized_shape.size() == 1,
-"normalize_shape.size() must be 1 but saw %zd",
-normalized_shape.size());
-ET_CHECK_MSG(
-input.scalar_type() == out.scalar_type(),
-"out and input must have the same type.");
-ET_CHECK_MSG(
-input.dim() == out.dim(),
-"out and input must have the same number of dimensions");
-ET_CHECK_MSG(
-input.scalar_type() == mean_out.scalar_type(),
-"mean_out and input must have the same type.");
-ET_CHECK_MSG(
-input.scalar_type() == rstd_out.scalar_type(),
-"rstd_out and input must have the same type.");
-
-if (input.sizes() == out.sizes()) {
-ET_CHECK_MSG(
-normalized_shape[0] == input.sizes()[input.dim() - 1],
-"Normalized shape value must match the size of input.");
-} else {
-// If we need to resize out to support dynamic input shapes, we can't count
-// on normalized_shape matching the shape of the input or output. But we
-// don't need to modify normalized_shape because it's not used in this
-// function besides some checks
-torch::executor::Error err = resize_tensor(out, input.sizes());
-ET_CHECK_MSG(
-err == torch::executor::Error::Ok,
-"Failed to resize out Tensor in opt_native_layer_norm_out");
-}
-
-const size_t input_ndim = input.dim();
-const size_t normalized_ndim = normalized_shape.size();
-
-const size_t axis = input_ndim - normalized_ndim;
-
-const size_t M = getLeadingDims(input, axis);
-const size_t N = getTrailingDims(input, axis - 1);
-
-// helper for generating the cases for different data types
-#define LAYER_NORM(ctype, dtype) \
-case ScalarType::dtype: \
-layer_norm<ctype>( \
-input, \
-gamma.value(), \
-beta.value(), \
-M, \
-N, \
-eps, \
-out, \
-mean_out, \
-rstd_out); \
-break;
-
-switch (input.scalar_type()) {
-// TODO support bfloat16
-ET_FORALL_FLOAT_TYPES(LAYER_NORM)
-default:
-ET_CHECK_MSG(
-false,
-"Unhandled dtype %" PRId8,
-static_cast<int8_t>(input.scalar_type()));
-}
-#undef LAYER_NORM
-return {out, mean_out, rstd_out};
+(void)ctx;
+
+std::tuple<Tensor&, Tensor&, Tensor&> ret_val(out, mean_out, rstd_out);
+
+ET_KERNEL_CHECK(
+ctx,
+check_layer_norm_args(
+input, normalized_shape, weight, bias, out, mean_out, rstd_out),
+InvalidArgument,
+ret_val);
+
+Tensor::SizesType mean_rstd_sizes[kTensorDimensionLimit];
+size_t mean_rstd_ndim = 0;
+get_layer_norm_out_target_size(
+input, normalized_shape, mean_rstd_sizes, &mean_rstd_ndim);
+
+ET_KERNEL_CHECK(
+ctx,
+resize_tensor(out, input.sizes()) == Error::Ok,
+InvalidArgument,
+ret_val);
+
+ET_KERNEL_CHECK(
+ctx,
+resize_tensor(mean_out, {mean_rstd_sizes, mean_rstd_ndim}) == Error::Ok,
+InvalidArgument,
+ret_val);
+
+ET_KERNEL_CHECK(
+ctx,
+resize_tensor(rstd_out, {mean_rstd_sizes, mean_rstd_ndim}) == Error::Ok,
+InvalidArgument,
+ret_val);
+
+ET_SWITCH_FLOAT_TYPES(input.scalar_type(), ctx, __func__, CTYPE, [&]() {
+layer_norm<CTYPE>(
+input, normalized_shape, weight, bias, eps, out, mean_out, rstd_out);
+});
+
+return ret_val;
}

} // namespace native
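For readers who want the row-wise math without the ExecuTorch plumbing, below is a minimal standalone sketch of what the reworked layer_norm() computes per row, including the mean and rstd outputs that the old kernel left as NAN. It is plain C++ over std::vector, written only as an illustration for this review: the function and parameter names are invented here and are not part of the ExecuTorch API, and the real kernel dispatches over dtypes and uses the vectorized map3 path shown in the diff. As a worked example of the M/N derivation in the new code: for an input of shape [2, 3, 768] with normalized_shape = [768], dim = 3 - 1 = 2, so M = 2 * 3 = 6 rows and N = 768 elements per row.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Reference-only sketch of the per-row work done by the optimized kernel.
// Names (layer_norm_reference, gamma, beta) are illustrative, not ExecuTorch API.
void layer_norm_reference(
    const std::vector<float>& input, // M x N, row-major
    const std::vector<float>& gamma, // size N, or empty to mean "all ones"
    const std::vector<float>& beta,  // size N, or empty to mean "all zeros"
    std::size_t M,
    std::size_t N,
    float eps,
    std::vector<float>& out,   // M x N
    std::vector<float>& mean,  // M
    std::vector<float>& rstd) { // M
  for (std::size_t i = 0; i < M; ++i) {
    const float* src = input.data() + i * N;
    float* dst = out.data() + i * N;

    // Row-wise mean and (biased) variance.
    float sum = 0.0f;
    float sq_sum = 0.0f;
    for (std::size_t j = 0; j < N; ++j) {
      sum += src[j];
      sq_sum += src[j] * src[j];
    }
    const float mean_val = sum / N;
    const float var = sq_sum / N - mean_val * mean_val;
    const float rstd_val = 1.0f / std::sqrt(var + eps);

    // Same scale/offset form as the kernel: y = (x * scale + offset) * gamma + beta.
    const float scale = rstd_val;
    const float offset = -rstd_val * mean_val;
    for (std::size_t j = 0; j < N; ++j) {
      const float g = gamma.empty() ? 1.0f : gamma[j];
      const float b = beta.empty() ? 0.0f : beta[j];
      dst[j] = (src[j] * scale + offset) * g + b;
    }

    // Unlike the old kernel, per-row mean and rstd are now real outputs.
    mean[i] = mean_val;
    rstd[i] = rstd_val;
  }
}
```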
1 change: 1 addition & 0 deletions kernels/optimized/cpu/targets.bzl
@@ -61,6 +61,7 @@ _OPTIMIZED_ATEN_OPS = (
name = "op_native_layer_norm",
deps = [
":moments_utils",
"//executorch/kernels/portable/cpu/util:normalization_ops_util",
],
),
op_target(name = "op_neg"),