
Commit b98d6a7

manuelcandales authored and facebook-github-bot committed
Fix & cleanup op layer_norm (#707)
Summary:
Pull Request resolved: #707

ghstack-source-id: 203341584
exported-using-ghexport

Reviewed By: SS-JIA

Differential Revision: D49848492

fbshipit-source-id: 1bc4c5f9f766231f3ec1488d046cc66f29bfd03f
1 parent bbdf579 commit b98d6a7

File tree

kernels/portable/cpu/op_native_layer_norm.cpp
kernels/portable/cpu/util/normalization_ops_util.cpp
kernels/portable/cpu/util/normalization_ops_util.h
kernels/test/op_native_layer_norm_test.cpp

4 files changed: +121 −45 lines

kernels/portable/cpu/op_native_layer_norm.cpp

Lines changed: 67 additions & 34 deletions
@@ -23,37 +23,65 @@ namespace {
 template <typename CTYPE>
 void layer_norm(
     const Tensor& input,
-    const Tensor& weight,
-    const Tensor& bias,
+    IntArrayRef normalized_shape,
+    const optional<Tensor>& weight,
+    const optional<Tensor>& bias,
     CTYPE eps,
     Tensor& out,
     Tensor& mean,
     Tensor& rstd) {
-  const CTYPE* input_data = input.const_data_ptr<CTYPE>();
-  const CTYPE* weight_data = weight.const_data_ptr<CTYPE>();
-  const CTYPE* bias_data = bias.const_data_ptr<CTYPE>();
+  size_t dim = input.dim() - normalized_shape.size();
+  size_t dim_size = input.size(dim);
+
+  size_t leading = getLeadingDims(input, dim);
+  size_t normalized = getTrailingDims(input, dim) * dim_size;
+
+  if (leading == 0) {
+    return;
+  }
+
   CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
   CTYPE* mean_data = mean.mutable_data_ptr<CTYPE>();
   CTYPE* rstd_data = rstd.mutable_data_ptr<CTYPE>();
 
-  size_t dim = input.size(input.dim() - 1);
+  if (normalized == 0) {
+    for (int i = 0; i < leading; ++i) {
+      mean_data[i] = static_cast<CTYPE>(0);
+      rstd_data[i] = static_cast<CTYPE>(NAN);
+    }
+    return;
+  }
 
-  size_t leading_dim = getLeadingDims(input, input.dim() - 1);
+  const CTYPE* input_data = input.const_data_ptr<CTYPE>();
+  const CTYPE* weight_data;
+  if (weight.has_value()) {
+    weight_data = weight.value().const_data_ptr<CTYPE>();
+  } else {
+    weight_data = nullptr;
+  }
+  const CTYPE* bias_data;
+  if (bias.has_value()) {
+    bias_data = bias.value().const_data_ptr<CTYPE>();
+  } else {
+    bias_data = nullptr;
+  }
 
-  for (int i = 0; i < leading_dim; ++i) {
-    const CTYPE* x = input_data + i * dim;
-    CTYPE* y = out_data + i * dim;
+  for (int i = 0; i < leading; ++i) {
+    const CTYPE* x = input_data + i * normalized;
+    CTYPE* y = out_data + i * normalized;
 
     // compute E[X] and Var[x] = E[x^2] - E[x]^2
-    CTYPE sum = reduce_add(x, dim);
-    CTYPE sq_sum = vec_powerf(x, dim);
-    CTYPE mean_value = sum / dim;
-    CTYPE variance = sq_sum / dim - mean_value * mean_value;
+    CTYPE sum = reduce_add(x, normalized);
+    CTYPE sq_sum = vec_powerf(x, normalized);
+    CTYPE mean_value = sum / normalized;
+    CTYPE variance = sq_sum / normalized - mean_value * mean_value;
     CTYPE std = std::sqrt(variance + eps);
 
     // Calculate the elements of output
-    for (int j = 0; j < dim; ++j) {
-      y[j] = (x[j] - mean_value) / std * weight_data[j] + bias_data[j];
+    for (int j = 0; j < normalized; ++j) {
+      CTYPE w = weight_data ? weight_data[j] : static_cast<CTYPE>(1);
+      CTYPE b = bias_data ? bias_data[j] : static_cast<CTYPE>(0);
+      y[j] = (x[j] - mean_value) / std * w + b;
     }
 
     mean_data[i] = mean_value;
@@ -87,27 +115,32 @@ std::tuple<Tensor&, Tensor&, Tensor&> native_layer_norm_out(
       InvalidArgument,
       ret_val);
 
-  if (input.sizes() == out.sizes()) {
-    ET_KERNEL_CHECK(
-        ctx,
-        normalized_shape[0] == input.sizes()[input.dim() - 1],
-        InvalidArgument,
-        ret_val);
-  } else {
-    // If we need to resize out to support dynamic input shapes, we can't count
-    // on normalized_shape matching the shape of the input or output. But we
-    // don't need to modify normalized_shape because it's not used in this
-    // function besides some checks
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_tensor(out, input.sizes()) == Error::Ok,
-        InvalidArgument,
-        ret_val);
-  }
+  Tensor::SizesType mean_rstd_sizes[kTensorDimensionLimit];
+  size_t mean_rstd_ndim = 0;
+  get_layer_norm_out_target_size(
+      input, normalized_shape, mean_rstd_sizes, &mean_rstd_ndim);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(out, input.sizes()) == Error::Ok,
+      InvalidArgument,
+      ret_val);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(mean_out, {mean_rstd_sizes, mean_rstd_ndim}) == Error::Ok,
+      InvalidArgument,
+      ret_val);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(rstd_out, {mean_rstd_sizes, mean_rstd_ndim}) == Error::Ok,
+      InvalidArgument,
+      ret_val);
 
   ET_SWITCH_FLOAT_TYPES(input.scalar_type(), ctx, __func__, CTYPE, [&]() {
     layer_norm<CTYPE>(
-        input, weight.value(), bias.value(), eps, out, mean_out, rstd_out);
+        input, normalized_shape, weight, bias, eps, out, mean_out, rstd_out);
   });
 
   return ret_val;
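
For readers skimming the diff, here is a minimal standalone sketch of the per-slice math the reworked kernel performs, written against raw float buffers rather than ExecuTorch tensors. The function name layer_norm_sketch and its flat-pointer interface are illustrative assumptions, reduce_add/vec_powerf are replaced by plain loops, and the rstd assignment (outside this hunk) follows the usual reciprocal-standard-deviation convention.

// Illustrative sketch only: optional weight/bias default to 1 and 0.
#include <cmath>
#include <cstddef>

void layer_norm_sketch(
    const float* input,   // leading * normalized elements
    size_t leading,       // product of dims before the normalized dims
    size_t normalized,    // product of the normalized dims
    const float* weight,  // may be nullptr (treated as all ones)
    const float* bias,    // may be nullptr (treated as all zeros)
    float eps,
    float* out,           // same size as input
    float* mean_out,      // leading elements
    float* rstd_out) {    // leading elements
  for (size_t i = 0; i < leading; ++i) {
    const float* x = input + i * normalized;
    float* y = out + i * normalized;

    // E[x] and Var[x] = E[x^2] - E[x]^2 over the normalized elements
    float sum = 0.0f;
    float sq_sum = 0.0f;
    for (size_t j = 0; j < normalized; ++j) {
      sum += x[j];
      sq_sum += x[j] * x[j];
    }
    float mean = sum / normalized;
    float variance = sq_sum / normalized - mean * mean;
    float std_dev = std::sqrt(variance + eps);

    // Normalize, then apply the optional affine transform
    for (size_t j = 0; j < normalized; ++j) {
      float w = weight ? weight[j] : 1.0f;
      float b = bias ? bias[j] : 0.0f;
      y[j] = (x[j] - mean) / std_dev * w + b;
    }
    mean_out[i] = mean;
    rstd_out[i] = 1.0f / std_dev;  // reciprocal standard deviation
  }
}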

kernels/portable/cpu/util/normalization_ops_util.cpp

Lines changed: 44 additions & 9 deletions
@@ -60,27 +60,62 @@ bool check_batch_norm_args(
 }
 
 bool check_layer_norm_args(
-    const Tensor& input,
+    const Tensor& in,
     IntArrayRef normalized_shape,
     const exec_aten::optional<Tensor>& weight,
     const exec_aten::optional<Tensor>& bias,
     Tensor& out,
     Tensor& mean_out,
     Tensor& rstd_out) {
-  ET_LOG_AND_RETURN_IF_FALSE(normalized_shape.size() == 1);
-  ET_LOG_AND_RETURN_IF_FALSE(weight.has_value());
+  size_t ndim = normalized_shape.size();
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      ndim >= 1,
+      "Expected normalized_shape to be at least 1-dimensional, i.e., containing at least one element.");
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      in.dim() >= ndim,
+      "Expected input tensor to have rank >= the length of normalized_shape.");
+  size_t shift = in.dim() - ndim;
+  for (size_t d = 0; d < ndim; ++d) {
+    ET_LOG_MSG_AND_RETURN_IF_FALSE(
+        in.size(d + shift) == normalized_shape[d],
+        "Expected normalized_shape to match the sizes of input's rightmost dimensions.");
+  }
+  exec_aten::SizesType shape[ndim];
+  for (size_t i = 0; i < ndim; ++i) {
+    shape[i] = static_cast<exec_aten::SizesType>(normalized_shape[i]);
+  }
+
   if (weight.has_value()) {
-    ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(input, weight.value()));
+    ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, weight.value()));
+    ET_LOG_AND_RETURN_IF_FALSE(
+        tensor_has_expected_size(weight.value(), {shape, ndim}));
   }
   if (bias.has_value()) {
-    ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(input, bias.value()));
+    ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, bias.value()));
+    ET_LOG_AND_RETURN_IF_FALSE(
+        tensor_has_expected_size(bias.value(), {shape, ndim}));
   }
-  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(input, out));
-  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(input, mean_out));
-  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(input, rstd_out));
-  ET_LOG_AND_RETURN_IF_FALSE(input.dim() == out.dim());
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, mean_out));
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, rstd_out));
   return true;
 }
 
+void get_layer_norm_out_target_size(
+    const Tensor& in,
+    IntArrayRef normalized_shape,
+    Tensor::SizesType* mean_rstd_sizes,
+    size_t* mean_rstd_ndim) {
+  *mean_rstd_ndim = in.dim();
+
+  for (size_t d = 0; d < in.dim(); ++d) {
+    if (d < in.dim() - normalized_shape.size()) {
+      mean_rstd_sizes[d] = in.size(d);
+    } else {
+      mean_rstd_sizes[d] = 1;
+    }
+  }
+}
+
 } // namespace executor
 } // namespace torch
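
To make the new helper's effect concrete, here is a hypothetical standalone version of the same shape rule with a worked example; the function mean_rstd_shape and the sizes used are illustrative, not part of the patch. The mean/rstd target shape keeps the leading dims of the input and collapses each normalized dim to 1.

// Illustrative sketch of the mean/rstd target-shape computation.
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<int64_t> mean_rstd_shape(
    const std::vector<int64_t>& in_sizes, size_t normalized_ndim) {
  std::vector<int64_t> sizes(in_sizes.size());
  const size_t shift = in_sizes.size() - normalized_ndim;
  for (size_t d = 0; d < in_sizes.size(); ++d) {
    // Leading dims are kept; normalized dims collapse to 1.
    sizes[d] = d < shift ? in_sizes[d] : 1;
  }
  return sizes;
}

int main() {
  // Input of shape [2, 3, 4, 5], normalized over its last two dims [4, 5].
  for (int64_t s : mean_rstd_shape({2, 3, 4, 5}, 2)) {
    std::printf("%lld ", static_cast<long long>(s));  // prints: 2 3 1 1
  }
  std::printf("\n");
  return 0;
}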

kernels/portable/cpu/util/normalization_ops_util.h

Lines changed: 6 additions & 0 deletions
@@ -34,5 +34,11 @@ bool check_layer_norm_args(
     Tensor& mean_out,
     Tensor& rstd_out);
 
+void get_layer_norm_out_target_size(
+    const Tensor& in,
+    IntArrayRef normalized_shape,
+    Tensor::SizesType* mean_rstd_sizes,
+    size_t* mean_rstd_ndim);
+
 } // namespace executor
 } // namespace torch

kernels/test/op_native_layer_norm_test.cpp

Lines changed: 4 additions & 2 deletions
@@ -91,8 +91,10 @@ void run_test_cases(std::vector<NativeLayerNormTestCase<DTYPE>> test_cases) {
     Tensor weight = tf.make(test_case.normalized_shape, test_case.weight_data);
     Tensor bias = tf.make(test_case.normalized_shape, test_case.bias_data);
     Tensor out0 = tf.zeros(test_case.sizes);
-    Tensor out1 = tf.zeros(test_case.sizes);
-    Tensor out2 = tf.zeros(test_case.sizes);
+    Tensor out1 = tf.zeros(
+        test_case.sizes, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
+    Tensor out2 = tf.zeros(
+        test_case.sizes, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
     auto normalized_shape_vec = std::vector<int64_t>(
         test_case.normalized_shape.begin(), test_case.normalized_shape.end());
     auto normalized_shape = exec_aten::ArrayRef<int64_t>(
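
A brief note on why only out1 and out2 change here: with the resize checks added in the kernel, the mean and rstd outputs are shrunk from the allocated input shape down to the shape whose normalized dims are 1, so the test tensors backing them must allow bounded (downward) resizing. A hypothetical 2-D case, with illustrative sizes that are not taken from the test file:

#include <cstdint>

// Hypothetical example: input sizes {5, 4}, normalized over its last dim {4}.
const int64_t input_sizes[] = {5, 4};      // allocated size of out1/out2
const int64_t normalized_shape[] = {4};    // rightmost dim is normalized
const int64_t mean_rstd_target[] = {5, 1}; // shape mean/rstd are resized to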
