@@ -23,37 +23,65 @@ namespace {
 template <typename CTYPE>
 void layer_norm(
     const Tensor& input,
-    const Tensor& weight,
-    const Tensor& bias,
+    IntArrayRef normalized_shape,
+    const optional<Tensor>& weight,
+    const optional<Tensor>& bias,
     CTYPE eps,
     Tensor& out,
     Tensor& mean,
     Tensor& rstd) {
-  const CTYPE* input_data = input.const_data_ptr<CTYPE>();
-  const CTYPE* weight_data = weight.const_data_ptr<CTYPE>();
-  const CTYPE* bias_data = bias.const_data_ptr<CTYPE>();
+  size_t dim = input.dim() - normalized_shape.size();
+  size_t dim_size = input.size(dim);
+
+  size_t leading = getLeadingDims(input, dim);
+  size_t normalized = getTrailingDims(input, dim) * dim_size;
+
+  if (leading == 0) {
+    return;
+  }
+
   CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
   CTYPE* mean_data = mean.mutable_data_ptr<CTYPE>();
   CTYPE* rstd_data = rstd.mutable_data_ptr<CTYPE>();

-  size_t dim = input.size(input.dim() - 1);
+  if (normalized == 0) {
+    for (int i = 0; i < leading; ++i) {
+      mean_data[i] = static_cast<CTYPE>(0);
+      rstd_data[i] = static_cast<CTYPE>(NAN);
+    }
+    return;
+  }

-  size_t leading_dim = getLeadingDims(input, input.dim() - 1);
+  const CTYPE* input_data = input.const_data_ptr<CTYPE>();
+  const CTYPE* weight_data;
+  if (weight.has_value()) {
+    weight_data = weight.value().const_data_ptr<CTYPE>();
+  } else {
+    weight_data = nullptr;
+  }
+  const CTYPE* bias_data;
+  if (bias.has_value()) {
+    bias_data = bias.value().const_data_ptr<CTYPE>();
+  } else {
+    bias_data = nullptr;
+  }

-  for (int i = 0; i < leading_dim; ++i) {
-    const CTYPE* x = input_data + i * dim;
-    CTYPE* y = out_data + i * dim;
+  for (int i = 0; i < leading; ++i) {
+    const CTYPE* x = input_data + i * normalized;
+    CTYPE* y = out_data + i * normalized;

     // compute E[X] and Var[x] = E[x^2] - E[x]^2
-    CTYPE sum = reduce_add(x, dim);
-    CTYPE sq_sum = vec_powerf(x, dim);
-    CTYPE mean_value = sum / dim;
-    CTYPE variance = sq_sum / dim - mean_value * mean_value;
+    CTYPE sum = reduce_add(x, normalized);
+    CTYPE sq_sum = vec_powerf(x, normalized);
+    CTYPE mean_value = sum / normalized;
+    CTYPE variance = sq_sum / normalized - mean_value * mean_value;
     CTYPE std = std::sqrt(variance + eps);

     // Calculate the elements of output
-    for (int j = 0; j < dim; ++j) {
-      y[j] = (x[j] - mean_value) / std * weight_data[j] + bias_data[j];
+    for (int j = 0; j < normalized; ++j) {
+      CTYPE w = weight_data ? weight_data[j] : static_cast<CTYPE>(1);
+      CTYPE b = bias_data ? bias_data[j] : static_cast<CTYPE>(0);
+      y[j] = (x[j] - mean_value) / std * w + b;
    }

     mean_data[i] = mean_value;
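For reference, here is a minimal standalone sketch of the math the patched kernel performs for one leading row, with reduce_add/vec_powerf replaced by plain loops and with weight/bias allowed to be null, defaulting to 1 and 0 as in the new inner loop. The names are illustrative, not from the source, and the final rstd write assumes the usual reciprocal-of-std convention (the hunk is truncated before that line).

#include <cmath>
#include <cstddef>

// Illustrative sketch only: normalizes one row of length n in a single pass,
// using Var[x] = E[x^2] - E[x]^2 as in the kernel's comment above.
void layer_norm_row(
    const float* x,
    const float* weight, // may be null -> treated as all-ones
    const float* bias, // may be null -> treated as all-zeros
    size_t n,
    float eps,
    float* y,
    float& mean_out,
    float& rstd_out) {
  float sum = 0.0f;
  float sq_sum = 0.0f;
  for (size_t j = 0; j < n; ++j) {
    sum += x[j];
    sq_sum += x[j] * x[j];
  }
  const float mean = sum / n;
  const float variance = sq_sum / n - mean * mean;
  const float std = std::sqrt(variance + eps);
  for (size_t j = 0; j < n; ++j) {
    const float w = weight ? weight[j] : 1.0f;
    const float b = bias ? bias[j] : 0.0f;
    y[j] = (x[j] - mean) / std * w + b;
  }
  mean_out = mean;
  rstd_out = 1.0f / std; // assumption: rstd is the reciprocal standard deviation
}

Note the single-pass moment computation: one reduction for the sum and one for the sum of squares is cheaper than a separate variance pass, at some cost in numerical robustness when the mean is large relative to the spread.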
@@ -87,27 +87,32 @@ std::tuple<Tensor&, Tensor&, Tensor&> native_layer_norm_out(
       InvalidArgument,
       ret_val);

-  if (input.sizes() == out.sizes()) {
-    ET_KERNEL_CHECK(
-        ctx,
-        normalized_shape[0] == input.sizes()[input.dim() - 1],
-        InvalidArgument,
-        ret_val);
-  } else {
-    // If we need to resize out to support dynamic input shapes, we can't count
-    // on normalized_shape matching the shape of the input or output. But we
-    // don't need to modify normalized_shape because it's not used in this
-    // function besides some checks
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_tensor(out, input.sizes()) == Error::Ok,
-        InvalidArgument,
-        ret_val);
-  }
+  Tensor::SizesType mean_rstd_sizes[kTensorDimensionLimit];
+  size_t mean_rstd_ndim = 0;
+  get_layer_norm_out_target_size(
+      input, normalized_shape, mean_rstd_sizes, &mean_rstd_ndim);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(out, input.sizes()) == Error::Ok,
+      InvalidArgument,
+      ret_val);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(mean_out, {mean_rstd_sizes, mean_rstd_ndim}) == Error::Ok,
+      InvalidArgument,
+      ret_val);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(rstd_out, {mean_rstd_sizes, mean_rstd_ndim}) == Error::Ok,
+      InvalidArgument,
+      ret_val);

   ET_SWITCH_FLOAT_TYPES(input.scalar_type(), ctx, __func__, CTYPE, [&]() {
     layer_norm<CTYPE>(
-        input, weight.value(), bias.value(), eps, out, mean_out, rstd_out);
+        input, normalized_shape, weight, bias, eps, out, mean_out, rstd_out);
   });

   return ret_val;
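For context, the shape the new mean/rstd resize presumably targets, assuming get_layer_norm_out_target_size follows ATen's native_layer_norm convention (the input's rank preserved, with each trailing normalized dimension collapsed to 1). The helper below is a hypothetical stand-in for illustration, not the actual utility:

#include <cstddef>
#include <vector>

// Hypothetical stand-in for get_layer_norm_out_target_size, assuming the
// ATen convention. Precondition: normalized_ndim <= input_sizes.size().
std::vector<size_t> mean_rstd_target_shape(
    const std::vector<size_t>& input_sizes,
    size_t normalized_ndim) {
  std::vector<size_t> sizes = input_sizes;
  // Collapse the trailing normalized dims to 1, e.g. input {2, 3, 4} with
  // normalized_shape {3, 4} -> mean/rstd {2, 1, 1}.
  for (size_t d = sizes.size() - normalized_ndim; d < sizes.size(); ++d) {
    sizes[d] = 1;
  }
  return sizes;
}

Resizing mean_out and rstd_out here, rather than requiring the caller to pre-shape them, matches how the out tensor is already resized to the input's sizes for dynamic shape support.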