Skip to content

Commit 4f00517

Browse files
manuelcandales authored and facebook-github-bot committed
Fix & cleanup optimized native_layer_norm (#812)
Summary: Pull Request resolved: #812 ghstack-source-id: 203616980 exported-using-ghexport Reviewed By: kirklandsign Differential Revision: D50146421 fbshipit-source-id: e7c084330202b5a957b8ba1ac39805c22080c345
1 parent 7ef721c commit 4f00517

File tree

2 files changed

+92
-97
lines changed

2 files changed

+92
-97
lines changed

kernels/optimized/cpu/op_native_layer_norm.cpp

Lines changed: 91 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <executorch/kernels/optimized/cpu/moments_utils.h>
1414
#include <executorch/kernels/optimized/vec/functional.h>
1515
#include <executorch/kernels/optimized/vec/vec.h>
16+
#include <executorch/kernels/portable/cpu/util/normalization_ops_util.h>
1617

1718
namespace torch {
1819
namespace executor {
@@ -25,148 +26,141 @@ namespace {
2526
template <typename CTYPE>
2627
void layer_norm(
2728
const Tensor& input,
28-
const Tensor& gamma,
29-
const Tensor& beta,
30-
const size_t M,
31-
const size_t N,
29+
IntArrayRef normalized_shape,
30+
const optional<Tensor>& weight,
31+
const optional<Tensor>& bias,
3232
CTYPE eps,
33-
Tensor& output,
33+
Tensor& out,
3434
Tensor& mean,
3535
Tensor& rstd) {
3636
using Vec = executorch::vec::Vectorized<CTYPE>;
3737

38-
const CTYPE* __restrict__ input_data = input.data_ptr<CTYPE>();
39-
const CTYPE* __restrict__ gamma_data = gamma.data_ptr<CTYPE>();
40-
const CTYPE* __restrict__ beta_data = beta.data_ptr<CTYPE>();
41-
CTYPE* __restrict__ output_data = output.data_ptr<CTYPE>();
38+
const size_t dim = input.dim() - normalized_shape.size();
39+
const size_t dim_size = input.size(dim);
40+
41+
const size_t M = getLeadingDims(input, dim);
42+
const size_t N = getTrailingDims(input, dim) * dim_size;
43+
44+
if (M == 0) {
45+
return;
46+
}
47+
48+
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
49+
CTYPE* mean_data = mean.mutable_data_ptr<CTYPE>();
50+
CTYPE* rstd_data = rstd.mutable_data_ptr<CTYPE>();
51+
52+
if (N == 0) {
53+
for (int i = 0; i < M; ++i) {
54+
mean_data[i] = static_cast<CTYPE>(0);
55+
rstd_data[i] = static_cast<CTYPE>(NAN);
56+
}
57+
return;
58+
}
59+
60+
const CTYPE* input_data = input.const_data_ptr<CTYPE>();
61+
const CTYPE* gamma_data;
62+
if (weight.has_value()) {
63+
gamma_data = weight.value().const_data_ptr<CTYPE>();
64+
} else {
65+
gamma_data = nullptr;
66+
}
67+
const CTYPE* beta_data;
68+
if (bias.has_value()) {
69+
beta_data = bias.value().const_data_ptr<CTYPE>();
70+
} else {
71+
beta_data = nullptr;
72+
}
4273

4374
const bool gamma_null = gamma_data == nullptr;
4475
const bool beta_null = beta_data == nullptr;
4576

4677
for (size_t i = 0; i < M; ++i) {
4778
const CTYPE* src_ptr = input_data + i * N;
48-
CTYPE* dst_ptr = output_data + i * N;
79+
CTYPE* dst_ptr = out_data + i * N;
4980

5081
CTYPE mean_val;
5182
CTYPE rstd_val;
5283
std::tie(mean_val, rstd_val) = RowwiseMoments(src_ptr, N);
5384
rstd_val = CTYPE(1) / std::sqrt(rstd_val + eps);
5485

5586
const CTYPE scale = rstd_val;
56-
const CTYPE bias = -rstd_val * mean_val;
87+
const CTYPE offset = -rstd_val * mean_val;
5788

5889
if (gamma_null || beta_null) {
5990
for (size_t j = 0; j < N; ++j) {
6091
const CTYPE gamma_v = gamma_null ? CTYPE(1) : gamma_data[j];
6192
const CTYPE beta_v = beta_null ? CTYPE(0) : beta_data[j];
62-
dst_ptr[j] = (src_ptr[j] * scale + bias) * gamma_v + beta_v;
93+
dst_ptr[j] = (src_ptr[j] * scale + offset) * gamma_v + beta_v;
6394
}
6495
} else {
6596
executorch::vec::map3<CTYPE>(
66-
[scale, bias](Vec x, Vec gamma, Vec beta) {
67-
return (x * Vec(scale) + Vec(bias)) * gamma + beta;
97+
[scale, offset](Vec x, Vec gamma, Vec beta) {
98+
return (x * Vec(scale) + Vec(offset)) * gamma + beta;
6899
},
69100
dst_ptr,
70101
src_ptr,
71102
gamma_data,
72103
beta_data,
73104
N);
74105
}
75-
}
76106

77-
// Assign NAN to mean and rstd. They are not used in seen examples.
78-
// Use NAN to make the error more obvious in case they are used.
79-
mean.data_ptr<CTYPE>()[0] = NAN;
80-
rstd.data_ptr<CTYPE>()[0] = NAN;
107+
mean_data[i] = mean_val;
108+
rstd_data[i] = rstd_val;
109+
}
81110
}
82111

83112
} // namespace
84113

85-
// native_layer_norm.out(Tensor input, int[] normalized_shape, Tensor? weight,
86-
// Tensor? bias, float eps, *, Tensor(a!) out, Tensor(b!) mean_out, Tensor(c!)
87-
// rstd_out) -> (Tensor(a!), Tensor(b!), Tensor(c!))
88-
//
89-
// Unlike the ATen implementation of native_layer_norm, mean_out and rstd_out
90-
// are not filled with any meaningful data. Instead, they are set to NAN to
91-
// easily detect if they are being used.
92114
std::tuple<Tensor&, Tensor&, Tensor&> opt_native_layer_norm_out(
93-
RuntimeContext& context,
115+
RuntimeContext& ctx,
94116
const Tensor& input,
95117
IntArrayRef normalized_shape,
96-
const exec_aten::optional<Tensor>& gamma,
97-
const exec_aten::optional<Tensor>& beta,
118+
const exec_aten::optional<Tensor>& weight,
119+
const exec_aten::optional<Tensor>& bias,
98120
double eps,
99121
Tensor& out,
100122
Tensor& mean_out,
101123
Tensor& rstd_out) {
102-
(void)context;
103-
104-
ET_CHECK_MSG(
105-
normalized_shape.size() == 1,
106-
"normalize_shape.size() must be 1 but saw %zd",
107-
normalized_shape.size());
108-
ET_CHECK_MSG(
109-
input.scalar_type() == out.scalar_type(),
110-
"out and input must have the same type.");
111-
ET_CHECK_MSG(
112-
input.dim() == out.dim(),
113-
"out and input must have the same number of dimensions");
114-
ET_CHECK_MSG(
115-
input.scalar_type() == mean_out.scalar_type(),
116-
"mean_out and input must have the same type.");
117-
ET_CHECK_MSG(
118-
input.scalar_type() == rstd_out.scalar_type(),
119-
"rstd_out and input must have the same type.");
120-
121-
if (input.sizes() == out.sizes()) {
122-
ET_CHECK_MSG(
123-
normalized_shape[0] == input.sizes()[input.dim() - 1],
124-
"Normalized shape value must match the size of input.");
125-
} else {
126-
// If we need to resize out to support dynamic input shapes, we can't count
127-
// on normalized_shape matching the shape of the input or output. But we
128-
// don't need to modify normalized_shape because it's not used in this
129-
// function besides some checks
130-
torch::executor::Error err = resize_tensor(out, input.sizes());
131-
ET_CHECK_MSG(
132-
err == torch::executor::Error::Ok,
133-
"Failed to resize out Tensor in opt_native_layer_norm_out");
134-
}
135-
136-
const size_t input_ndim = input.dim();
137-
const size_t normalized_ndim = normalized_shape.size();
138-
139-
const size_t axis = input_ndim - normalized_ndim;
140-
141-
const size_t M = getLeadingDims(input, axis);
142-
const size_t N = getTrailingDims(input, axis - 1);
143-
144-
// helper for generating the cases for different data types
145-
#define LAYER_NORM(ctype, dtype) \
146-
case ScalarType::dtype: \
147-
layer_norm<ctype>( \
148-
input, \
149-
gamma.value(), \
150-
beta.value(), \
151-
M, \
152-
N, \
153-
eps, \
154-
out, \
155-
mean_out, \
156-
rstd_out); \
157-
break;
158-
159-
switch (input.scalar_type()) {
160-
// TODO support bfloat16
161-
ET_FORALL_FLOAT_TYPES(LAYER_NORM)
162-
default:
163-
ET_CHECK_MSG(
164-
false,
165-
"Unhandled dtype %" PRId8,
166-
static_cast<int8_t>(input.scalar_type()));
167-
}
168-
#undef LAYER_NORM
169-
return {out, mean_out, rstd_out};
124+
(void)ctx;
125+
126+
std::tuple<Tensor&, Tensor&, Tensor&> ret_val(out, mean_out, rstd_out);
127+
128+
ET_KERNEL_CHECK(
129+
ctx,
130+
check_layer_norm_args(
131+
input, normalized_shape, weight, bias, out, mean_out, rstd_out),
132+
InvalidArgument,
133+
ret_val);
134+
135+
Tensor::SizesType mean_rstd_sizes[kTensorDimensionLimit];
136+
size_t mean_rstd_ndim = 0;
137+
get_layer_norm_out_target_size(
138+
input, normalized_shape, mean_rstd_sizes, &mean_rstd_ndim);
139+
140+
ET_KERNEL_CHECK(
141+
ctx,
142+
resize_tensor(out, input.sizes()) == Error::Ok,
143+
InvalidArgument,
144+
ret_val);
145+
146+
ET_KERNEL_CHECK(
147+
ctx,
148+
resize_tensor(mean_out, {mean_rstd_sizes, mean_rstd_ndim}) == Error::Ok,
149+
InvalidArgument,
150+
ret_val);
151+
152+
ET_KERNEL_CHECK(
153+
ctx,
154+
resize_tensor(rstd_out, {mean_rstd_sizes, mean_rstd_ndim}) == Error::Ok,
155+
InvalidArgument,
156+
ret_val);
157+
158+
ET_SWITCH_FLOAT_TYPES(input.scalar_type(), ctx, __func__, CTYPE, [&]() {
159+
layer_norm<CTYPE>(
160+
input, normalized_shape, weight, bias, eps, out, mean_out, rstd_out);
161+
});
162+
163+
return ret_val;
170164
}
171165

172166
} // namespace native

kernels/optimized/cpu/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ _OPTIMIZED_ATEN_OPS = (
6161
name = "op_native_layer_norm",
6262
deps = [
6363
":moments_utils",
64+
"//executorch/kernels/portable/cpu/util:normalization_ops_util",
6465
],
6566
),
6667
op_target(name = "op_neg"),

0 commit comments

Comments (0)