Commit b35b665

manuelcandales authored and facebook-github-bot committed

Dtype compliance: native_layer_norm

Reviewed By: SS-JIA
Differential Revision: D48371008
fbshipit-source-id: cbfb97ce1d29931fd27eec27a24a4eab8857dc50

1 parent 8aa5db4 commit b35b665

File tree

  kernels/portable/cpu/op_native_layer_norm.cpp
  kernels/portable/cpu/targets.bzl
  kernels/portable/cpu/util/normalization_ops_util.cpp
  kernels/portable/cpu/util/normalization_ops_util.h

4 files changed: +67 -46 lines changed

kernels/portable/cpu/op_native_layer_norm.cpp

Lines changed: 34 additions & 46 deletions

@@ -6,6 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+#include <executorch/kernels/portable/cpu/util/normalization_ops_util.h>
 #include <executorch/kernels/portable/cpu/vec_ops.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <cmath>
@@ -18,27 +19,30 @@ namespace native {
 using Tensor = exec_aten::Tensor;
 
 namespace {
+
 template <typename CTYPE>
 void layer_norm(
     const Tensor& input,
     const Tensor& weight,
     const Tensor& bias,
     CTYPE eps,
-    Tensor& output,
+    Tensor& out,
     Tensor& mean,
     Tensor& rstd) {
   const CTYPE* input_data = input.const_data_ptr<CTYPE>();
-  CTYPE* output_data = output.mutable_data_ptr<CTYPE>();
   const CTYPE* weight_data = weight.const_data_ptr<CTYPE>();
   const CTYPE* bias_data = bias.const_data_ptr<CTYPE>();
+  CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
+  CTYPE* mean_data = mean.mutable_data_ptr<CTYPE>();
+  CTYPE* rstd_data = rstd.mutable_data_ptr<CTYPE>();
 
   size_t dim = input.size(input.dim() - 1);
 
   size_t leading_dim = getLeadingDims(input, input.dim() - 1);
 
   for (int i = 0; i < leading_dim; ++i) {
     const CTYPE* x = input_data + i * dim;
-    CTYPE* y = output_data + i * dim;
+    CTYPE* y = out_data + i * dim;
 
     // compute E[X] and Var[x] = E[x^2] - E[x]^2
     CTYPE sum = reduce_add(x, dim);
@@ -51,13 +55,12 @@ void layer_norm(
     for (int j = 0; j < dim; ++j) {
       y[j] = (x[j] - mean_value) / std * weight_data[j] + bias_data[j];
     }
-  }
 
-  // Assign NAN to mean and rstd. They are not used in seen examples.
-  // Use NAN to make the error more obvious in case they are used.
-  mean.mutable_data_ptr<CTYPE>()[0] = NAN;
-  rstd.mutable_data_ptr<CTYPE>()[0] = NAN;
+    mean_data[i] = mean_value;
+    rstd_data[i] = 1.0 / std;
+  }
 }
+
 } // namespace
 
 // native_layer_norm.out(Tensor input, int[] normalized_shape, Tensor? weight,
@@ -75,54 +78,39 @@ std::tuple<Tensor&, Tensor&, Tensor&> native_layer_norm_out(
     Tensor& out,
     Tensor& mean_out,
     Tensor& rstd_out) {
-  ET_CHECK_MSG(
-      normalized_shape.size() == 1,
-      "normalize_shape.size() must be 1 but saw %zd",
-      normalized_shape.size());
-  ET_CHECK_MSG(weight.has_value(), "Missing weight tensor");
-  ET_CHECK_MSG(
-      input.scalar_type() == out.scalar_type(),
-      "out and input must have the same type.");
-  ET_CHECK_MSG(
-      input.dim() == out.dim(),
-      "out and input must have the same number of dimensions");
-  ET_CHECK_MSG(
-      input.scalar_type() == mean_out.scalar_type(),
-      "mean_out and input must have the same type.");
-  ET_CHECK_MSG(
-      input.scalar_type() == rstd_out.scalar_type(),
-      "rstd_out and input must have the same type.");
+  std::tuple<Tensor&, Tensor&, Tensor&> ret_val(out, mean_out, rstd_out);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      check_layer_norm_args(
+          input, normalized_shape, weight, bias, out, mean_out, rstd_out),
+      InvalidArgument,
+      ret_val);
 
   if (input.sizes() == out.sizes()) {
-    ET_CHECK_MSG(
+    ET_KERNEL_CHECK(
+        ctx,
         normalized_shape[0] == input.sizes()[input.dim() - 1],
-        "Normalized shape value must match the size of input.");
+        InvalidArgument,
+        ret_val);
   } else {
     // If we need to resize out to support dynamic input shapes, we can't count
     // on normalized_shape matching the shape of the input or output. But we
     // don't need to modify normalized_shape because it's not used in this
     // function besides some checks
-    torch::executor::Error err = resize_tensor(out, input.sizes());
-    ET_CHECK_MSG(
-        err == torch::executor::Error::Ok,
-        "Failed to resize out Tensor in native_layer_norm_out");
+    ET_KERNEL_CHECK(
+        ctx,
+        resize_tensor(out, input.sizes()) == Error::Ok,
+        InvalidArgument,
+        ret_val);
   }
 
-  // helper for generating the cases for different data types
-#define LAYER_NORM(ctype, dtype)                                            \
-  case ScalarType::dtype:                                                   \
-    layer_norm<ctype>(                                                      \
-        input, weight.value(), bias.value(), eps, out, mean_out, rstd_out); \
-    break;
-
-  switch (input.scalar_type()) {
-    // TODO support bfloat16
-    ET_FORALL_FLOAT_TYPES(LAYER_NORM)
-    default:
-      ET_CHECK_MSG(false, "Unhandled dtype %hhd", input.scalar_type());
-  }
-#undef LAYER_NORM
-  return {out, mean_out, rstd_out};
+  ET_SWITCH_FLOAT_TYPES(input.scalar_type(), ctx, __func__, CTYPE, [&]() {
+    layer_norm<CTYPE>(
+        input, weight.value(), bias.value(), eps, out, mean_out, rstd_out);
+  });
+
+  return ret_val;
 }
 
 } // namespace native
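
For reference, the per-row math the rewritten kernel now performs (normalized output plus a real mean and rstd per row, instead of NAN placeholders) is small enough to sketch standalone. This is a plain-C++ illustration, not the ExecuTorch kernel: the real code operates on Tensor objects and uses vec_ops helpers such as reduce_add, and folding eps into the standard deviation as sqrt(var + eps) is an assumption based on standard layer-norm semantics, since that part of the kernel sits outside this diff.

#include <cmath>
#include <cstdio>
#include <vector>

// Per-row layer norm over the last dimension. Mirrors the updated kernel:
// mean and rstd are written for every row i, inside the loop.
void layer_norm_sketch(
    const std::vector<float>& input,
    const std::vector<float>& weight,
    const std::vector<float>& bias,
    size_t leading_dim,
    size_t dim,
    float eps,
    std::vector<float>& out,
    std::vector<float>& mean,
    std::vector<float>& rstd) {
  for (size_t i = 0; i < leading_dim; ++i) {
    const float* x = input.data() + i * dim;
    float* y = out.data() + i * dim;

    // compute E[X] and Var[x] = E[x^2] - E[x]^2, as in the kernel's comment
    float sum = 0.0f;
    float sq_sum = 0.0f;
    for (size_t j = 0; j < dim; ++j) {
      sum += x[j];
      sq_sum += x[j] * x[j];
    }
    const float mean_value = sum / dim;
    const float variance = sq_sum / dim - mean_value * mean_value;
    const float std_value = std::sqrt(variance + eps);

    for (size_t j = 0; j < dim; ++j) {
      y[j] = (x[j] - mean_value) / std_value * weight[j] + bias[j];
    }
    mean[i] = mean_value;        // was NAN before this change
    rstd[i] = 1.0f / std_value;  // reciprocal standard deviation
  }
}

int main() {
  const size_t leading_dim = 2, dim = 4;
  std::vector<float> input = {1, 2, 3, 4, 2, 4, 6, 8};
  std::vector<float> weight(dim, 1.0f), bias(dim, 0.0f);
  std::vector<float> out(leading_dim * dim);
  std::vector<float> mean(leading_dim), rstd(leading_dim);
  layer_norm_sketch(input, weight, bias, leading_dim, dim, 1e-5f, out, mean, rstd);
  for (size_t i = 0; i < leading_dim; ++i) {
    std::printf("row %zu: mean=%f rstd=%f\n", i, mean[i], rstd[i]);
  }
  return 0;
}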

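Similarly, the ET_SWITCH_FLOAT_TYPES call that replaces the old #define LAYER_NORM / switch block follows a common pattern: a macro-generated switch over the runtime dtype invokes a generic lambda with the matching C++ type. The sketch below shows that pattern using a hypothetical switch_float_types helper; the real macro's interface (which also takes ctx and an operator name, and binds the type as CTYPE inside the lambda) differs in detail.

#include <cstdio>

// Hypothetical stand-in for the runtime's dtype enum.
enum class ScalarType { Float, Double };

// Dispatch a generic callable with the C++ type matching the runtime dtype.
// A type-tag argument carries the selected type into the lambda.
template <typename Fn>
void switch_float_types(ScalarType t, Fn fn) {
  switch (t) {
    case ScalarType::Float:
      fn(float{});
      break;
    case ScalarType::Double:
      fn(double{});
      break;
  }
}

int main() {
  switch_float_types(ScalarType::Float, [&](auto tag) {
    using CTYPE = decltype(tag);  // the kernel would call layer_norm<CTYPE> here
    std::printf("dispatched with sizeof(CTYPE) = %zu\n", sizeof(CTYPE));
  });
  return 0;
}
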
kernels/portable/cpu/targets.bzl

Lines changed: 1 addition & 0 deletions

@@ -527,6 +527,7 @@ _ATEN_OPS = (
         name = "op_native_layer_norm",
         deps = [
             ":vec_ops",
+            "//executorch/kernels/portable/cpu/util:normalization_ops_util",
         ],
     ),
     op_target(

kernels/portable/cpu/util/normalization_ops_util.cpp

Lines changed: 23 additions & 0 deletions

@@ -55,5 +55,28 @@ bool check_batch_norm_args(
   return true;
 }
 
+bool check_layer_norm_args(
+    const Tensor& input,
+    IntArrayRef normalized_shape,
+    const exec_aten::optional<Tensor>& weight,
+    const exec_aten::optional<Tensor>& bias,
+    Tensor& out,
+    Tensor& mean_out,
+    Tensor& rstd_out) {
+  ET_LOG_AND_RETURN_IF_FALSE(normalized_shape.size() == 1);
+  ET_LOG_AND_RETURN_IF_FALSE(weight.has_value());
+  if (weight.has_value()) {
+    ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(input, weight.value()));
+  }
+  if (bias.has_value()) {
+    ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(input, bias.value()));
+  }
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(input, out));
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(input, mean_out));
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(input, rstd_out));
+  ET_LOG_AND_RETURN_IF_FALSE(input.dim() == out.dim());
+  return true;
+}
+
 } // namespace executor
 } // namespace torch
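
Unlike ET_CHECK_MSG, which aborts on failure, check_layer_norm_args logs and returns false, which is what lets the kernel surface InvalidArgument through ET_KERNEL_CHECK instead of crashing. Below is a rough standalone sketch of that log-and-return pattern; the macro body, the message format, and the check_args_sketch function are illustrative assumptions, not the runtime's actual definition of ET_LOG_AND_RETURN_IF_FALSE.

#include <cstdio>

// Illustrative log-and-return macro: on a failed condition, log the
// stringified expression and make the enclosing bool function return false.
#define LOG_AND_RETURN_IF_FALSE(cond)                    \
  do {                                                   \
    if (!(cond)) {                                       \
      std::fprintf(stderr, "Check failed: %s\n", #cond); \
      return false;                                      \
    }                                                    \
  } while (0)

// Example check function in the style of check_layer_norm_args.
bool check_args_sketch(int input_dim, int out_dim, bool has_weight) {
  LOG_AND_RETURN_IF_FALSE(has_weight);
  LOG_AND_RETURN_IF_FALSE(input_dim == out_dim);
  return true;
}

int main() {
  // First call passes; second logs "Check failed: input_dim == out_dim".
  std::printf("%d\n", check_args_sketch(3, 3, true)); // 1
  std::printf("%d\n", check_args_sketch(3, 2, true)); // 0
  return 0;
}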

kernels/portable/cpu/util/normalization_ops_util.h

Lines changed: 9 additions & 0 deletions

@@ -23,5 +23,14 @@ bool check_batch_norm_args(
     double eps,
     Tensor& out);
 
+bool check_layer_norm_args(
+    const Tensor& input,
+    IntArrayRef normalized_shape,
+    const exec_aten::optional<Tensor>& weight,
+    const exec_aten::optional<Tensor>& bias,
+    Tensor& out,
+    Tensor& mean_out,
+    Tensor& rstd_out);
+
 } // namespace executor
 } // namespace torch
