#include <executorch/kernels/optimized/cpu/moments_utils.h>
#include <executorch/kernels/optimized/vec/functional.h>
#include <executorch/kernels/optimized/vec/vec.h>
+#include <executorch/kernels/portable/cpu/util/normalization_ops_util.h>

namespace torch {
namespace executor {
@@ -25,148 +26,141 @@ namespace {
template <typename CTYPE>
void layer_norm(
    const Tensor& input,
-    const Tensor& gamma,
-    const Tensor& beta,
-    const size_t M,
-    const size_t N,
+    IntArrayRef normalized_shape,
+    const optional<Tensor>& weight,
+    const optional<Tensor>& bias,
    CTYPE eps,
-    Tensor& output,
+    Tensor& out,
    Tensor& mean,
    Tensor& rstd) {
  using Vec = executorch::vec::Vectorized<CTYPE>;

-  const CTYPE* __restrict__ input_data = input.data_ptr<CTYPE>();
-  const CTYPE* __restrict__ gamma_data = gamma.data_ptr<CTYPE>();
-  const CTYPE* __restrict__ beta_data = beta.data_ptr<CTYPE>();
-  CTYPE* __restrict__ output_data = output.data_ptr<CTYPE>();
+  const size_t dim = input.dim() - normalized_shape.size();
+  const size_t dim_size = input.size(dim);
+
+  const size_t M = getLeadingDims(input, dim);
+  const size_t N = getTrailingDims(input, dim) * dim_size;
+
+  if (M == 0) {
+    return;
+  }
+
+  CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
+  CTYPE* mean_data = mean.mutable_data_ptr<CTYPE>();
+  CTYPE* rstd_data = rstd.mutable_data_ptr<CTYPE>();
+
+  if (N == 0) {
+    for (int i = 0; i < M; ++i) {
+      mean_data[i] = static_cast<CTYPE>(0);
+      rstd_data[i] = static_cast<CTYPE>(NAN);
+    }
+    return;
+  }
+
+  const CTYPE* input_data = input.const_data_ptr<CTYPE>();
+  const CTYPE* gamma_data;
+  if (weight.has_value()) {
+    gamma_data = weight.value().const_data_ptr<CTYPE>();
+  } else {
+    gamma_data = nullptr;
+  }
+  const CTYPE* beta_data;
+  if (bias.has_value()) {
+    beta_data = bias.value().const_data_ptr<CTYPE>();
+  } else {
+    beta_data = nullptr;
+  }

  const bool gamma_null = gamma_data == nullptr;
  const bool beta_null = beta_data == nullptr;

  for (size_t i = 0; i < M; ++i) {
    const CTYPE* src_ptr = input_data + i * N;
-    CTYPE* dst_ptr = output_data + i * N;
+    CTYPE* dst_ptr = out_data + i * N;

    CTYPE mean_val;
    CTYPE rstd_val;
    std::tie(mean_val, rstd_val) = RowwiseMoments(src_ptr, N);
    rstd_val = CTYPE(1) / std::sqrt(rstd_val + eps);

    const CTYPE scale = rstd_val;
-    const CTYPE bias = -rstd_val * mean_val;
+    const CTYPE offset = -rstd_val * mean_val;

    if (gamma_null || beta_null) {
      for (size_t j = 0; j < N; ++j) {
        const CTYPE gamma_v = gamma_null ? CTYPE(1) : gamma_data[j];
        const CTYPE beta_v = beta_null ? CTYPE(0) : beta_data[j];
-        dst_ptr[j] = (src_ptr[j] * scale + bias) * gamma_v + beta_v;
+        dst_ptr[j] = (src_ptr[j] * scale + offset) * gamma_v + beta_v;
      }
    } else {
      executorch::vec::map3<CTYPE>(
-          [scale, bias](Vec x, Vec gamma, Vec beta) {
-            return (x * Vec(scale) + Vec(bias)) * gamma + beta;
+          [scale, offset](Vec x, Vec gamma, Vec beta) {
+            return (x * Vec(scale) + Vec(offset)) * gamma + beta;
          },
          dst_ptr,
          src_ptr,
          gamma_data,
          beta_data,
          N);
    }
-  }

-  // Assign NAN to mean and rstd. They are not used in seen examples.
-  // Use NAN to make the error more obvious in case they are used.
-  mean.data_ptr<CTYPE>()[0] = NAN;
-  rstd.data_ptr<CTYPE>()[0] = NAN;
+    mean_data[i] = mean_val;
+    rstd_data[i] = rstd_val;
+  }
}

} // namespace

-// native_layer_norm.out(Tensor input, int[] normalized_shape, Tensor? weight,
-// Tensor? bias, float eps, *, Tensor(a!) out, Tensor(b!) mean_out, Tensor(c!)
-// rstd_out) -> (Tensor(a!), Tensor(b!), Tensor(c!))
-//
-// Unlike the ATen implementation of native_layer_norm, mean_out and rstd_out
-// are not filled with any meaningful data. Instead, they are set to NAN to
-// easily detect if they are being used.
std::tuple<Tensor&, Tensor&, Tensor&> opt_native_layer_norm_out(
-    RuntimeContext& context,
+    RuntimeContext& ctx,
    const Tensor& input,
    IntArrayRef normalized_shape,
-    const exec_aten::optional<Tensor>& gamma,
-    const exec_aten::optional<Tensor>& beta,
+    const exec_aten::optional<Tensor>& weight,
+    const exec_aten::optional<Tensor>& bias,
    double eps,
    Tensor& out,
    Tensor& mean_out,
    Tensor& rstd_out) {
-  (void)context;
-
-  ET_CHECK_MSG(
-      normalized_shape.size() == 1,
-      "normalize_shape.size() must be 1 but saw %zd",
-      normalized_shape.size());
-  ET_CHECK_MSG(
-      input.scalar_type() == out.scalar_type(),
-      "out and input must have the same type.");
-  ET_CHECK_MSG(
-      input.dim() == out.dim(),
-      "out and input must have the same number of dimensions");
-  ET_CHECK_MSG(
-      input.scalar_type() == mean_out.scalar_type(),
-      "mean_out and input must have the same type.");
-  ET_CHECK_MSG(
-      input.scalar_type() == rstd_out.scalar_type(),
-      "rstd_out and input must have the same type.");
-
-  if (input.sizes() == out.sizes()) {
-    ET_CHECK_MSG(
-        normalized_shape[0] == input.sizes()[input.dim() - 1],
-        "Normalized shape value must match the size of input.");
-  } else {
-    // If we need to resize out to support dynamic input shapes, we can't count
-    // on normalized_shape matching the shape of the input or output. But we
-    // don't need to modify normalized_shape because it's not used in this
-    // function besides some checks
-    torch::executor::Error err = resize_tensor(out, input.sizes());
-    ET_CHECK_MSG(
-        err == torch::executor::Error::Ok,
-        "Failed to resize out Tensor in opt_native_layer_norm_out");
-  }
-
-  const size_t input_ndim = input.dim();
-  const size_t normalized_ndim = normalized_shape.size();
-
-  const size_t axis = input_ndim - normalized_ndim;
-
-  const size_t M = getLeadingDims(input, axis);
-  const size_t N = getTrailingDims(input, axis - 1);
-
-  // helper for generating the cases for different data types
-#define LAYER_NORM(ctype, dtype) \
-  case ScalarType::dtype:        \
-    layer_norm<ctype>(           \
-        input,                   \
-        gamma.value(),           \
-        beta.value(),            \
-        M,                       \
-        N,                       \
-        eps,                     \
-        out,                     \
-        mean_out,                \
-        rstd_out);               \
-    break;
-
-  switch (input.scalar_type()) {
-    // TODO support bfloat16
-    ET_FORALL_FLOAT_TYPES(LAYER_NORM)
-    default:
-      ET_CHECK_MSG(
-          false,
-          "Unhandled dtype %" PRId8,
-          static_cast<int8_t>(input.scalar_type()));
-  }
-#undef LAYER_NORM
-  return {out, mean_out, rstd_out};
+  (void)ctx;
+
+  std::tuple<Tensor&, Tensor&, Tensor&> ret_val(out, mean_out, rstd_out);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      check_layer_norm_args(
+          input, normalized_shape, weight, bias, out, mean_out, rstd_out),
+      InvalidArgument,
+      ret_val);
+
+  Tensor::SizesType mean_rstd_sizes[kTensorDimensionLimit];
+  size_t mean_rstd_ndim = 0;
+  get_layer_norm_out_target_size(
+      input, normalized_shape, mean_rstd_sizes, &mean_rstd_ndim);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(out, input.sizes()) == Error::Ok,
+      InvalidArgument,
+      ret_val);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(mean_out, {mean_rstd_sizes, mean_rstd_ndim}) == Error::Ok,
+      InvalidArgument,
+      ret_val);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(rstd_out, {mean_rstd_sizes, mean_rstd_ndim}) == Error::Ok,
+      InvalidArgument,
+      ret_val);
+
+  ET_SWITCH_FLOAT_TYPES(input.scalar_type(), ctx, __func__, CTYPE, [&]() {
+    layer_norm<CTYPE>(
+        input, normalized_shape, weight, bias, eps, out, mean_out, rstd_out);
+  });
+
+  return ret_val;
}

} // namespace native
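
For readers following the refactor, the per-row math is unchanged: each of the M rows gets a mean and rstd = 1/sqrt(var + eps), and the output is the fused affine form (x * scale + offset) * gamma + beta with scale = rstd and offset = -rstd * mean, which is algebraically the usual (x - mean) * rstd * gamma + beta. What changes is that mean and rstd are now stored per row instead of being set to NAN. Below is a minimal scalar sketch of that per-row computation, useful only for cross-checking the kernel; it assumes float data and a biased (divide-by-N) variance for the moments, and reference_layer_norm_row is a hypothetical name, not part of this file.

#include <cmath>
#include <cstddef>

// Hypothetical scalar reference for one row of length N; not part of the kernel.
// gamma/beta may be nullptr, mirroring the optional weight/bias handling above.
void reference_layer_norm_row(
    const float* src,
    const float* gamma,
    const float* beta,
    size_t N,
    float eps,
    float* dst,
    float& mean_out,
    float& rstd_out) {
  // Rowwise moments: mean and biased variance over the N normalized elements.
  float sum = 0.0f;
  float sum_sq = 0.0f;
  for (size_t j = 0; j < N; ++j) {
    sum += src[j];
    sum_sq += src[j] * src[j];
  }
  const float mean = sum / static_cast<float>(N);
  const float var = sum_sq / static_cast<float>(N) - mean * mean;
  const float rstd = 1.0f / std::sqrt(var + eps);

  // Same fused form as the kernel: (x * scale + offset) * gamma + beta.
  const float scale = rstd;
  const float offset = -rstd * mean;
  for (size_t j = 0; j < N; ++j) {
    const float g = gamma ? gamma[j] : 1.0f;
    const float b = beta ? beta[j] : 0.0f;
    dst[j] = (src[j] * scale + offset) * g + b;
  }

  mean_out = mean;
  rstd_out = rstd;
}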