
Fix _native_batch_norm_legit_no_stats_out #6929

Merged (1 commit) on Nov 21, 2024
132 changes: 119 additions & 13 deletions kernels/portable/cpu/op_native_batch_norm.cpp
@@ -10,6 +10,7 @@
#include <tuple>

#include <executorch/kernels/portable/cpu/util/normalization_ops_util.h>
#include <executorch/kernels/portable/cpu/vec_ops.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

@@ -18,6 +19,7 @@ namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using SizesType = exec_aten::SizesType;

std::tuple<Tensor&, Tensor&, Tensor&> _native_batch_norm_legit_no_training_out(
KernelRuntimeContext& ctx,
@@ -184,27 +186,131 @@ std::tuple<Tensor&, Tensor&, Tensor&> _native_batch_norm_legit_no_stats_out(
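(Before this change, the body of this function was a stub: it discarded its arguments and contained two contradictory checks, one requiring training == false with the message "Portable kernels only support inference mode!" and one requiring training == true with the message "running_mean & running_var must be provided during inference!", so the op could never succeed. The code below is the full implementation that replaces it.)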
Tensor& mean_out,
Tensor& invstd_out) {
(void)ctx;

std::tuple<Tensor&, Tensor&, Tensor&> ret_val(out, mean_out, invstd_out);

ET_KERNEL_CHECK(
ctx,
check_batch_norm_args(
in,
weight,
bias,
exec_aten::optional<Tensor>(),
exec_aten::optional<Tensor>(),
momentum,
eps,
out,
mean_out,
invstd_out),
InvalidArgument,
ret_val);

ET_KERNEL_CHECK(
ctx,
is_contiguous_dim_order(in.dim_order().data(), in.dim_order().size()),
InvalidArgument,
ret_val);

ET_KERNEL_CHECK(
ctx,
tensors_have_same_dim_order(in, out, mean_out, invstd_out),
InvalidArgument,
ret_val);

if (weight.has_value()) {
ET_KERNEL_CHECK(
ctx,
tensors_have_same_dim_order(in, weight.value()),
InvalidArgument,
ret_val);
}

if (bias.has_value()) {
ET_KERNEL_CHECK(
ctx,
tensors_have_same_dim_order(in, bias.value()),
InvalidArgument,
ret_val);
}

ET_KERNEL_CHECK(ctx, in.dim() >= 2, InvalidArgument, ret_val);

size_t N = in.size(0);
size_t C = in.size(1);
size_t inner = getTrailingDims(in, 1);
size_t elements_per_channel = N * inner;

ET_KERNEL_CHECK(
ctx,
resize_tensor(out, in.sizes()) == Error::Ok,
InvalidArgument,
ret_val);

ET_KERNEL_CHECK(
ctx,
resize_tensor(mean_out, {static_cast<SizesType>(C)}) == Error::Ok,
InvalidArgument,
ret_val);

ET_KERNEL_CHECK(
ctx,
resize_tensor(invstd_out, {static_cast<SizesType>(C)}) == Error::Ok,
InvalidArgument,
ret_val);

constexpr auto name = "_native_batch_norm_legit.no_stats_out";

ET_SWITCH_FLOAT_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] {
const CTYPE* in_data = in.const_data_ptr<CTYPE>();
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
CTYPE* mean_data = mean_out.mutable_data_ptr<CTYPE>();
CTYPE* invstd_data = invstd_out.mutable_data_ptr<CTYPE>();

// Compute sum and sum of squares for each channel
for (size_t b = 0; b < N; ++b) {
const CTYPE* b_in_data = in_data + b * C * inner;
for (size_t c = 0; c < C; ++c) {
const CTYPE* x = b_in_data + c * inner;

CTYPE sum = reduce_add(x, inner);
CTYPE sq_sum = vec_powerf(x, inner);

mean_data[c] += sum;
invstd_data[c] += sq_sum;
}
}

// Compute mean and invstd for each channel
for (size_t c = 0; c < C; ++c) {
CTYPE mean = mean_data[c] / elements_per_channel;
// Var[x] = E[x^2] - E[x]^2
CTYPE var = invstd_data[c] / elements_per_channel - mean * mean;
CTYPE invstd = 1.0 / std::sqrt(var + eps);
mean_data[c] = mean;
invstd_data[c] = invstd;
}

for (size_t i = 0; i < N; ++i) {
for (size_t c = 0; c < C; ++c) {
CTYPE mean = mean_data[c];
CTYPE invstd = invstd_data[c];
CTYPE weight_val = 1;
if (weight.has_value()) {
weight_val = weight.value().const_data_ptr<CTYPE>()[c];
}
CTYPE bias_val = 0;
if (bias.has_value()) {
bias_val = bias.value().const_data_ptr<CTYPE>()[c];
}
for (size_t j = 0; j < inner; ++j) {
*out_data = (*in_data - mean) * invstd * weight_val + bias_val;
out_data++;
in_data++;
}
}
}
});

return ret_val;
}
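For reference, here is a minimal standalone sketch (not part of this PR, plain C++ with illustrative names) of the per-channel computation the kernel performs: a sum and sum of squares over all N * inner elements of a channel, then a biased variance and invstd = 1 / sqrt(var + eps), then per-element normalization. The sample values are channel 0 of the 3D test case added below (two batches of four elements), so the printed mean and invstd should match that test's expected out1[0] and out2[0].

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Channel 0 of the {2, 3, 4} test input: batch 0 -> {0, 1, 4, 9},
  // batch 1 -> {144, 169, 196, 225}. Here N = 2 and inner = 4.
  std::vector<double> channel = {0, 1, 4, 9, 144, 169, 196, 225};
  const double eps = 1e-5;

  // Pass 1: accumulate sum and sum of squares (the roles played by
  // reduce_add and vec_powerf in the kernel).
  double sum = 0.0, sq_sum = 0.0;
  for (double v : channel) {
    sum += v;
    sq_sum += v * v;
  }

  // Pass 2: biased mean/variance over elements_per_channel = N * inner, then invstd.
  const double n = static_cast<double>(channel.size());
  const double mean = sum / n;                        // 93.5
  const double var = sq_sum / n - mean * mean;        // Var[x] = E[x^2] - E[x]^2
  const double invstd = 1.0 / std::sqrt(var + eps);   // ~0.0108070

  // Pass 3, per element: out = (x - mean) * invstd * weight + bias
  // (weight = 1 and bias = 0 when the optionals are absent).
  const double y0 = (channel[0] - mean) * invstd;     // ~-1.0104566
  std::printf("mean=%f invstd=%f y0=%f\n", mean, invstd, y0);
  return 0;
}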
36 changes: 23 additions & 13 deletions kernels/portable/cpu/util/normalization_ops_util.cpp
@@ -19,35 +19,35 @@ bool check_batch_norm_args(
const Tensor& in,
const exec_aten::optional<Tensor>& weight,
const exec_aten::optional<Tensor>& bias,
const exec_aten::optional<Tensor>& running_mean,
const exec_aten::optional<Tensor>& running_var,
double momentum,
double eps,
Tensor& out,
Tensor& mean_out,
Tensor& var_out) {
// All tensors must be the same dtype
if (weight.has_value()) {
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, weight.value()));
}
if (bias.has_value()) {
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, bias.value()));
}
if (running_mean.has_value()) {
ET_LOG_AND_RETURN_IF_FALSE(
tensors_have_same_dtype(in, running_mean.value()));
}
if (running_var.has_value()) {
ET_LOG_AND_RETURN_IF_FALSE(
tensors_have_same_dtype(in, running_var.value()));
}
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, mean_out));
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, var_out));

size_t C_dim = in.dim() >= 1 ? 1 : 0;
// All parameter tensors must be of dim 1 and have length equal to the
// channels dim of in
if (weight.has_value()) {
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(weight.value(), 1));
ET_LOG_AND_RETURN_IF_FALSE(
@@ -58,6 +58,16 @@
ET_LOG_AND_RETURN_IF_FALSE(
tensors_have_same_size_at_dims(bias.value(), 0, in, C_dim));
}
if (running_mean.has_value()) {
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(running_mean.value(), 1));
ET_LOG_AND_RETURN_IF_FALSE(
tensors_have_same_size_at_dims(running_mean.value(), 0, in, C_dim));
}
if (running_var.has_value()) {
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(running_var.value(), 1));
ET_LOG_AND_RETURN_IF_FALSE(
tensors_have_same_size_at_dims(running_var.value(), 0, in, C_dim));
}

return true;
}
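Making running_mean and running_var optional here lets the no-stats variant above reuse the same argument checks by passing empty optionals, while the variants that do receive running stats still get dtype, rank, and channel-size validation for those tensors when they are present.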
4 changes: 2 additions & 2 deletions kernels/portable/cpu/util/normalization_ops_util.h
@@ -17,8 +17,8 @@ bool check_batch_norm_args(
const Tensor& in,
const exec_aten::optional<Tensor>& weight,
const exec_aten::optional<Tensor>& bias,
const exec_aten::optional<Tensor>& running_mean,
const exec_aten::optional<Tensor>& running_var,
double momentum,
double eps,
Tensor& out,
135 changes: 135 additions & 0 deletions kernels/test/op_native_batch_norm_test.cpp
@@ -78,6 +78,33 @@ class OpNativeBatchNormLegitOutTest : public OperatorTest {
}
};

class OpNativeBatchNormLegitNoStatsOutTest : public OperatorTest {
protected:
::std::tuple<exec_aten::Tensor&, exec_aten::Tensor&, exec_aten::Tensor&>
op_native_batch_norm_legit_no_stats_out(
const exec_aten::Tensor& input,
const exec_aten::optional<exec_aten::Tensor>& weight,
const exec_aten::optional<exec_aten::Tensor>& bias,
bool training,
double momentum,
double eps,
exec_aten::Tensor& out0,
exec_aten::Tensor& out1,
exec_aten::Tensor& out2) {
return torch::executor::aten::_native_batch_norm_legit_outf(
context_,
input,
weight,
bias,
training,
momentum,
eps,
out0,
out1,
out2);
}
};
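The new fixture routes through the same torch::executor::aten::_native_batch_norm_legit_outf entry point as the existing tests, but exercises the overload that takes no running_mean/running_var and instead returns the computed batch statistics in out1 and out2.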

TEST_F(OpNativeBatchNormLegitNoTrainingOutTest, SampleAtomicTest2D) {
torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;

@@ -949,3 +976,111 @@ TEST_F(OpNativeBatchNormLegitOutTest, SampleAtomicTest2D) {
EXPECT_TENSOR_CLOSE(out1, out1_expected);
EXPECT_TENSOR_CLOSE(out2, out2_expected);
}

TEST_F(OpNativeBatchNormLegitNoStatsOutTest, SampleAtomicTest2D) {
torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;

exec_aten::Tensor input =
tfFloat.make({3, 4}, {0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121});
exec_aten::optional<exec_aten::Tensor> weight =
exec_aten::optional<exec_aten::Tensor>();
exec_aten::optional<exec_aten::Tensor> bias =
exec_aten::optional<exec_aten::Tensor>();
bool training = true;
double momentum = 1e-3;
double eps = 1e-5;
exec_aten::Tensor out0 = tfFloat.zeros({3, 4});
exec_aten::Tensor out1 = tfFloat.zeros({4});
exec_aten::Tensor out2 = tfFloat.zeros({4});
exec_aten::Tensor out0_expected = tfFloat.make(
{3, 4},
{-0.98058063,
-1.03422451,
-1.06904495,
-1.09332705,
-0.39223224,
-0.31822300,
-0.26726127,
-0.23017406,
1.37281299,
1.35244739,
1.33630610,
1.32350123});
exec_aten::Tensor out1_expected =
tfFloat.make({4}, {26.66666603, 35.66666794, 46.66666794, 59.66666794});
exec_aten::Tensor out2_expected =
tfFloat.make({4}, {0.03677177, 0.02983340, 0.02505574, 0.02157882});
op_native_batch_norm_legit_no_stats_out(
input, weight, bias, training, momentum, eps, out0, out1, out2);
EXPECT_TENSOR_CLOSE(out0, out0_expected);
EXPECT_TENSOR_CLOSE(out1, out1_expected);
EXPECT_TENSOR_CLOSE(out2, out2_expected);
}
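As a quick sanity check on the expected values: channel 0 of the input is {0, 16, 64} (one element per row, so N = 3 and inner = 1), and the kernel's biased statistics give

\mu_0 = (0 + 16 + 64) / 3 = 26.6667,
\sigma_0^2 = (0^2 + 16^2 + 64^2) / 3 - \mu_0^2 \approx 739.556,
1 / \sqrt{\sigma_0^2 + 10^{-5}} \approx 0.0367718,

which match out1_expected[0] and out2_expected[0]. The first output element is (0 - \mu_0) \cdot 0.0367718 \approx -0.980581, since weight and bias are absent (scale 1, shift 0).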

TEST_F(OpNativeBatchNormLegitNoStatsOutTest, SampleAtomicTest3D) {
torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;

exec_aten::Tensor input = tfFloat.make(
{2, 3, 4}, {0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121,
144, 169, 196, 225, 256, 289, 324, 361, 400, 441, 484, 529});
exec_aten::optional<exec_aten::Tensor> weight =
exec_aten::optional<exec_aten::Tensor>();
exec_aten::optional<exec_aten::Tensor> bias =
exec_aten::optional<exec_aten::Tensor>();
bool training = true;
double momentum = 1e-3;
double eps = 1e-5;
exec_aten::Tensor out0 = tfFloat.zeros({2, 3, 4});
exec_aten::Tensor out1 = tfFloat.zeros({3});
exec_aten::Tensor out2 = tfFloat.zeros({3});
exec_aten::Tensor out0_expected = tfFloat.make(
{2, 3, 4},
{-1.01045656, -0.99964952, -0.96722847, -0.91319335, -1.08850884,
-1.02468753, -0.94668359, -0.85449719, -1.12558389, -1.03595889,
-0.93578988, -0.82507670, 0.54575467, 0.81593025, 1.10771990,
1.42112350, 0.61339414, 0.84740579, 1.09560001, 1.35797679,
0.64582670, 0.86198103, 1.08867943, 1.32592189});
exec_aten::Tensor out1_expected = tfFloat.make({3}, {93.5, 169.5, 277.5});
exec_aten::Tensor out2_expected =
tfFloat.make({3}, {0.01080702, 0.00709126, 0.00527206});
op_native_batch_norm_legit_no_stats_out(
input, weight, bias, training, momentum, eps, out0, out1, out2);
EXPECT_TENSOR_CLOSE(out0, out0_expected);
EXPECT_TENSOR_CLOSE(out1, out1_expected);
EXPECT_TENSOR_CLOSE(out2, out2_expected);
}

TEST_F(OpNativeBatchNormLegitNoStatsOutTest, SampleAtomicTest4D) {
torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;

exec_aten::Tensor input =
tfFloat.make({2, 3, 2, 2}, {0, 1, 4, 9, 16, 25, 36, 49,
64, 81, 100, 121, 144, 169, 196, 225,
256, 289, 324, 361, 400, 441, 484, 529});
exec_aten::optional<exec_aten::Tensor> weight =
exec_aten::optional<exec_aten::Tensor>(
tfFloat.make({3}, {1.1, 0.7, 0.3}));
exec_aten::optional<exec_aten::Tensor> bias =
exec_aten::optional<exec_aten::Tensor>(
tfFloat.make({3}, {1.7, 2.2, 3.3}));
bool training = true;
double momentum = 1e-3;
double eps = 1e-5;
exec_aten::Tensor out0 = tfFloat.zeros({2, 3, 2, 2});
exec_aten::Tensor out1 = tfFloat.zeros({3});
exec_aten::Tensor out2 = tfFloat.zeros({3});
exec_aten::Tensor out0_expected = tfFloat.make(
{2, 3, 2, 2},
{0.58849782, 0.60038555, 0.63604873, 0.69548732, 1.43804383, 1.48271883,
1.53732157, 1.60185206, 2.96232486, 2.98921227, 3.01926303, 3.05247688,
2.30033016, 2.59752321, 2.91849184, 3.26323581, 2.62937593, 2.79318404,
2.96691990, 3.15058374, 3.49374819, 3.55859423, 3.62660384, 3.69777656});
exec_aten::Tensor out1_expected = tfFloat.make({3}, {93.5, 169.5, 277.5});
exec_aten::Tensor out2_expected =
tfFloat.make({3}, {0.01080702, 0.00709126, 0.00527206});
op_native_batch_norm_legit_no_stats_out(
input, weight, bias, training, momentum, eps, out0, out1, out2);
EXPECT_TENSOR_CLOSE(out0, out0_expected);
EXPECT_TENSOR_CLOSE(out1, out1_expected);
EXPECT_TENSOR_CLOSE(out2, out2_expected);
}
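Here the weight and bias only scale and shift the normalized values; for the first element of channel 0,

(0 - 93.5) \cdot 0.01080702 \cdot 1.1 + 1.7 \approx 0.588498,

which matches out0_expected[0]. The expected out1/out2 are identical to the 3D test above because each channel holds the same four contiguous values per batch, only reshaped to 2x2.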
@@ -867,6 +867,7 @@ ATEN_OPS = (
op_target(
name = "op_native_batch_norm",
deps = [
":vec_ops",
"//executorch/kernels/portable/cpu/util:normalization_ops_util",
],
),
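This build change for the kernel's op target mirrors the vec_ops.h include added above: the :vec_ops dependency supplies the reduce_add and vec_powerf helpers used for the per-channel sums.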