  * LICENSE file in the root directory of this source tree.
  */
 
+#include <executorch/kernels/portable/cpu/util/normalization_ops_util.h>
 #include <executorch/kernels/portable/cpu/vec_ops.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <cmath>
@@ -18,27 +19,30 @@ namespace native {
 using Tensor = exec_aten::Tensor;
 
 namespace {
+
 template <typename CTYPE>
 void layer_norm(
     const Tensor& input,
     const Tensor& weight,
     const Tensor& bias,
     CTYPE eps,
-    Tensor& output,
+    Tensor& out,
     Tensor& mean,
     Tensor& rstd) {
   const CTYPE* input_data = input.const_data_ptr<CTYPE>();
-  CTYPE* output_data = output.mutable_data_ptr<CTYPE>();
   const CTYPE* weight_data = weight.const_data_ptr<CTYPE>();
   const CTYPE* bias_data = bias.const_data_ptr<CTYPE>();
+  CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
+  CTYPE* mean_data = mean.mutable_data_ptr<CTYPE>();
+  CTYPE* rstd_data = rstd.mutable_data_ptr<CTYPE>();
 
   size_t dim = input.size(input.dim() - 1);
 
   size_t leading_dim = getLeadingDims(input, input.dim() - 1);
 
   for (int i = 0; i < leading_dim; ++i) {
     const CTYPE* x = input_data + i * dim;
-    CTYPE* y = output_data + i * dim;
+    CTYPE* y = out_data + i * dim;
 
     // compute E[X] and Var[x] = E[x^2] - E[x]^2
     CTYPE sum = reduce_add(x, dim);
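
For reference, the comment above is the identity the kernel relies on: with a running sum and sum of squares over a row, the mean and a (biased) variance fall out in a single pass. Below is a minimal standalone sketch of that arithmetic in plain C++ with no ExecuTorch helpers; the function name row_stats and the placement of eps are illustrative only, since the actual std/variance computation sits past the end of this hunk.

    #include <cmath>
    #include <cstddef>

    // Illustrative only: per-row mean and reciprocal standard deviation using
    // Var[x] = E[x^2] - E[x]^2, computed in one pass over the row.
    template <typename CTYPE>
    void row_stats(const CTYPE* x, size_t dim, CTYPE eps, CTYPE& mean, CTYPE& rstd) {
      CTYPE sum = 0;
      CTYPE sum_sq = 0;
      for (size_t j = 0; j < dim; ++j) {
        sum += x[j];
        sum_sq += x[j] * x[j];
      }
      mean = sum / dim;                        // E[x]
      CTYPE var = sum_sq / dim - mean * mean;  // E[x^2] - E[x]^2
      rstd = static_cast<CTYPE>(1) / std::sqrt(var + eps);
    }
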
@@ -51,13 +55,12 @@ void layer_norm(
     for (int j = 0; j < dim; ++j) {
       y[j] = (x[j] - mean_value) / std * weight_data[j] + bias_data[j];
     }
-  }
 
-  // Assign NAN to mean and rstd. They are not used in seen examples.
-  // Use NAN to make the error more obvious in case they are used.
-  mean.mutable_data_ptr<CTYPE>()[0] = NAN;
-  rstd.mutable_data_ptr<CTYPE>()[0] = NAN;
+    mean_data[i] = mean_value;
+    rstd_data[i] = 1.0 / std;
+  }
 }
+
 } // namespace
 
 // native_layer_norm.out(Tensor input, int[] normalized_shape, Tensor? weight,
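
The visible change in this hunk: instead of poisoning mean and rstd with NAN, the kernel now stores one mean and one reciprocal standard deviation per leading row, matching what the normalization loop itself uses. As a hedged illustration, the stored statistics let a caller reconstruct each output row; the helper below is not part of the kernel, and its name and the tol parameter are made up for the example.

    #include <cmath>
    #include <cstddef>

    // Illustrative consistency check: every output row should match the affine
    // transform rebuilt from the per-row statistics returned by the kernel,
    // y[i][j] = (x[i][j] - mean[i]) * rstd[i] * weight[j] + bias[j].
    template <typename CTYPE>
    bool rows_consistent(
        const CTYPE* x,
        const CTYPE* y,
        const CTYPE* w,
        const CTYPE* b,
        const CTYPE* mean,
        const CTYPE* rstd,
        size_t leading_dim,
        size_t dim,
        CTYPE tol) {
      for (size_t i = 0; i < leading_dim; ++i) {
        for (size_t j = 0; j < dim; ++j) {
          const CTYPE expected = (x[i * dim + j] - mean[i]) * rstd[i] * w[j] + b[j];
          if (std::abs(y[i * dim + j] - expected) > tol) {
            return false;
          }
        }
      }
      return true;
    }
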
@@ -75,54 +78,39 @@ std::tuple<Tensor&, Tensor&, Tensor&> native_layer_norm_out(
     Tensor& out,
     Tensor& mean_out,
     Tensor& rstd_out) {
-  ET_CHECK_MSG(
-      normalized_shape.size() == 1,
-      "normalize_shape.size() must be 1 but saw %zd",
-      normalized_shape.size());
-  ET_CHECK_MSG(weight.has_value(), "Missing weight tensor");
-  ET_CHECK_MSG(
-      input.scalar_type() == out.scalar_type(),
-      "out and input must have the same type.");
-  ET_CHECK_MSG(
-      input.dim() == out.dim(),
-      "out and input must have the same number of dimensions");
-  ET_CHECK_MSG(
-      input.scalar_type() == mean_out.scalar_type(),
-      "mean_out and input must have the same type.");
-  ET_CHECK_MSG(
-      input.scalar_type() == rstd_out.scalar_type(),
-      "rstd_out and input must have the same type.");
+  std::tuple<Tensor&, Tensor&, Tensor&> ret_val(out, mean_out, rstd_out);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      check_layer_norm_args(
+          input, normalized_shape, weight, bias, out, mean_out, rstd_out),
+      InvalidArgument,
+      ret_val);
 
   if (input.sizes() == out.sizes()) {
-    ET_CHECK_MSG(
+    ET_KERNEL_CHECK(
+        ctx,
         normalized_shape[0] == input.sizes()[input.dim() - 1],
-        "Normalized shape value must match the size of input.");
+        InvalidArgument,
+        ret_val);
   } else {
     // If we need to resize out to support dynamic input shapes, we can't count
     // on normalized_shape matching the shape of the input or output. But we
     // don't need to modify normalized_shape because it's not used in this
     // function besides some checks
-    torch::executor::Error err = resize_tensor(out, input.sizes());
-    ET_CHECK_MSG(
-        err == torch::executor::Error::Ok,
-        "Failed to resize out Tensor in native_layer_norm_out");
+    ET_KERNEL_CHECK(
+        ctx,
+        resize_tensor(out, input.sizes()) == Error::Ok,
+        InvalidArgument,
+        ret_val);
   }
 
-  // helper for generating the cases for different data types
-#define LAYER_NORM(ctype, dtype)                                              \
-  case ScalarType::dtype:                                                     \
-    layer_norm<ctype>(                                                        \
-        input, weight.value(), bias.value(), eps, out, mean_out, rstd_out);   \
-    break;
-
-  switch (input.scalar_type()) {
-    // TODO support bfloat16
-    ET_FORALL_FLOAT_TYPES(LAYER_NORM)
-    default:
-      ET_CHECK_MSG(false, "Unhandled dtype %hhd", input.scalar_type());
-  }
-#undef LAYER_NORM
-  return {out, mean_out, rstd_out};
+  ET_SWITCH_FLOAT_TYPES(input.scalar_type(), ctx, __func__, CTYPE, [&]() {
+    layer_norm<CTYPE>(
+        input, weight.value(), bias.value(), eps, out, mean_out, rstd_out);
+  });
+
+  return ret_val;
 }
 
 } // namespace native
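
Taken together, this hunk swaps the aborting ET_CHECK_MSG asserts for ET_KERNEL_CHECK, which reports the failure through the kernel context and returns ret_val (the tuple of outputs) instead of terminating, and it replaces the hand-rolled switch/#define dtype dispatch with an ET_SWITCH_FLOAT_TYPES lambda. A minimal plain-C++ sketch of the check-and-return control flow follows; the sketch_-prefixed names are hypothetical stand-ins, not the real ExecuTorch macros.

    #include <cstdio>

    // Hypothetical stand-in for the kernel runtime context.
    struct sketch_context {
      bool failed = false;
    };

    // Non-fatal check: on failure, flag the context and return the caller's
    // declared value instead of aborting the process (what ET_CHECK_MSG did).
    #define SKETCH_KERNEL_CHECK(ctx, cond, retval) \
      do {                                         \
        if (!(cond)) {                             \
          (ctx).failed = true;                     \
          return retval;                           \
        }                                          \
      } while (0)

    // Usage mirroring the shape of native_layer_norm_out above: the early
    // return hands back the (unmodified) outputs and lets the runtime decide
    // how to surface the error.
    int sketch_kernel(sketch_context& ctx, bool args_ok, int ret_val) {
      SKETCH_KERNEL_CHECK(ctx, args_ok, ret_val);
      std::printf("checks passed, running kernel body\n");
      return ret_val;
    }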