
Commit 69b4e75

SS-JIA authored and facebook-github-bot committed
Add native_batch_norm
Summary: Add implementation of native_batch_norm in portable kernels

Reviewed By: kimishpatel

Differential Revision: D47878889

fbshipit-source-id: 3be2221c6df04fd73810a189484f1290b2908aca
1 parent 43fa887 commit 69b4e75
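
For reference, in no-training mode the kernel applies the standard per-channel batch-norm transform using the tracked statistics; `momentum` is accepted but unused, since the running statistics are not updated:

$$y = \frac{x - \hat{\mu}_c}{\sqrt{\hat{\sigma}_c^2 + \epsilon}} \cdot \gamma_c + \beta_c$$

where $\hat{\mu}_c$ and $\hat{\sigma}_c^2$ come from `running_mean` and `running_var`, and $\gamma_c$, $\beta_c$ from the optional `weight` and `bias` (defaulting to 1 and 0).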

File tree

8 files changed: +1017 -0 lines changed

kernels/portable/cpu/op_native_batch_norm.cpp

Lines changed: 83 additions & 0 deletions

@@ -0,0 +1,83 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cmath>
#include <tuple>

#include <executorch/kernels/portable/cpu/util/normalization_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;

std::tuple<Tensor&, Tensor&, Tensor&> _native_batch_norm_legit_no_training_out(
    RuntimeContext& ctx,
    const Tensor& in,
    const exec_aten::optional<Tensor>& weight,
    const exec_aten::optional<Tensor>& bias,
    const Tensor& running_mean,
    const Tensor& running_var,
    double momentum,
    double eps,
    Tensor& out,
    Tensor& mean_out,
    Tensor& var_out) {
  (void)ctx;

  ET_CHECK(resize_tensor(out, in.sizes()) == Error::Ok);

  check_batch_norm_args(
      in, weight, bias, running_mean, running_var, momentum, eps, out);
  // For now, only support the default dim order
  ET_CHECK(is_default_dim_order(in.dim_order().data(), in.dim_order().size()));

  size_t C_dim = in.dim() >= 1 ? 1 : 0;
  size_t C = in.size(C_dim);
  size_t outer = getLeadingDims(in, C_dim);
  size_t inner = getTrailingDims(in, C_dim);

  ET_SWITCH_FLOAT_TYPES(
      in.scalar_type(), ctx, "native_batch_norm_legit_no_training", CTYPE, [&] {
        const CTYPE* in_data = in.const_data_ptr<CTYPE>();
        CTYPE* out_data = out.mutable_data_ptr<CTYPE>();

        const CTYPE* const mean_data = running_mean.const_data_ptr<CTYPE>();
        const CTYPE* const var_data = running_var.const_data_ptr<CTYPE>();

        for (size_t i = 0; i < outer; ++i) {
          for (size_t c = 0; c < C; ++c) {
            CTYPE mean = mean_data[c];
            CTYPE var = var_data[c];
            CTYPE invstd = 1.0 / std::sqrt(var + eps);
            CTYPE weight_val = 1;
            if (weight.has_value()) {
              weight_val = weight.value().const_data_ptr<CTYPE>()[c];
            }
            CTYPE bias_val = 0;
            if (bias.has_value()) {
              bias_val = bias.value().const_data_ptr<CTYPE>()[c];
            }
            for (size_t j = 0; j < inner; ++j) {
              *out_data = (*in_data - mean) * invstd * weight_val + bias_val;
              out_data++;
              in_data++;
            }
          }
        }
      });

  return {out, mean_out, var_out};
}

} // namespace native
} // namespace executor
} // namespace torch
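
For intuition, here is a minimal standalone sketch of the same inference-mode math on raw float arrays. It is illustrative only, not part of the commit: the helper name and the contiguous-layout assumption (`outer` elements before the channel dim, `inner` after, exactly what getLeadingDims/getTrailingDims compute) are mine.

```cpp
#include <cmath>
#include <cstddef>

// Reference sketch (hypothetical): normalize a contiguous buffer laid out
// as [outer, C, inner], mirroring the kernel's loop structure above.
void batch_norm_inference_ref(
    const float* in,
    float* out,
    const float* mean, // length C (running_mean)
    const float* var, // length C (running_var)
    const float* weight, // length C, or nullptr for the default of 1
    const float* bias, // length C, or nullptr for the default of 0
    std::size_t outer,
    std::size_t C,
    std::size_t inner,
    double eps) {
  for (std::size_t i = 0; i < outer; ++i) {
    for (std::size_t c = 0; c < C; ++c) {
      const float invstd = 1.0f / std::sqrt(var[c] + static_cast<float>(eps));
      const float w = weight ? weight[c] : 1.0f;
      const float b = bias ? bias[c] : 0.0f;
      for (std::size_t j = 0; j < inner; ++j) {
        *out++ = (*in++ - mean[c]) * invstd * w + b;
      }
    }
  }
}
```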

kernels/portable/cpu/targets.bzl

Lines changed: 6 additions & 0 deletions
@@ -505,6 +505,12 @@ _ATEN_OPS = (
             ":scalar_utils",
         ],
     ),
+    op_target(
+        name = "op_native_batch_norm",
+        deps = [
+            "//executorch/kernels/portable/cpu/util:normalization_ops_util",
+        ],
+    ),
     op_target(
         name = "op_native_layer_norm",
         deps = [
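
`op_target` is the portable-kernels helper macro that, by convention, builds each `op_<name>.cpp` into its own small library; the only extra dependency here is the new `normalization_ops_util` target defined below.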
kernels/portable/cpu/util/normalization_ops_util.cpp

Lines changed: 55 additions & 0 deletions

@@ -0,0 +1,55 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstring>

#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <executorch/runtime/platform/assert.h>

namespace torch {
namespace executor {

using Tensor = exec_aten::Tensor;

void check_batch_norm_args(
    const Tensor& in,
    const exec_aten::optional<Tensor>& weight,
    const exec_aten::optional<Tensor>& bias,
    const Tensor& running_mean,
    const Tensor& running_var,
    double momentum,
    double eps,
    Tensor& out) {
  // All tensors must be the same dtype
  ET_CHECK_SAME_DTYPE3(in, running_mean, running_var);
  ET_CHECK_SAME_DTYPE2(in, out);
  if (weight.has_value()) {
    ET_CHECK_SAME_DTYPE2(in, weight.value());
  }
  if (bias.has_value()) {
    ET_CHECK_SAME_DTYPE2(in, bias.value());
  }

  size_t C_dim = in.dim() >= 1 ? 1 : 0;
  // All parameter tensors must be of dim 1 and have length equal to the
  // channels dim of in
  ET_CHECK(running_mean.dim() == 1 && running_mean.size(0) == in.size(C_dim));
  ET_CHECK(running_var.dim() == 1 && running_var.size(0) == in.size(C_dim));
  if (weight.has_value()) {
    ET_CHECK(
        weight.value().dim() == 1 && weight.value().size(0) == in.size(C_dim));
  }
  if (bias.has_value()) {
    ET_CHECK(bias.value().dim() == 1 && bias.value().size(0) == in.size(C_dim));
  }
}

} // namespace executor
} // namespace torch
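
Concretely: a float32 input of shape [2, 3, 4, 5] has C_dim = 1 and C = 3, so `running_mean`, `running_var`, and (if present) `weight` and `bias` must each be 1-D float32 tensors of length 3. A `running_mean` of shape [3, 1] trips the `dim() == 1` check, and a float64 `weight` trips `ET_CHECK_SAME_DTYPE2`.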
kernels/portable/cpu/util/normalization_ops_util.h

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {

void check_batch_norm_args(
    const Tensor& in,
    const exec_aten::optional<Tensor>& weight,
    const exec_aten::optional<Tensor>& bias,
    const Tensor& running_mean,
    const Tensor& running_var,
    double momentum,
    double eps,
    Tensor& out);

} // namespace executor
} // namespace torch

kernels/portable/cpu/util/targets.bzl

Lines changed: 13 additions & 0 deletions
@@ -50,6 +50,19 @@ def define_common_targets():
         visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/..."],
     )
 
+    runtime.cxx_library(
+        name = "normalization_ops_util",
+        srcs = ["normalization_ops_util.cpp"],
+        exported_headers = [
+            "normalization_ops_util.h",
+        ],
+        compiler_flags = ["-Wno-missing-prototypes"],
+        deps = [
+            "//executorch/runtime/kernel:kernel_includes",
+        ],
+        visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/..."],
+    )
+
     runtime.cxx_library(
         name = "transpose_util",
         exported_headers = [

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
@@ -22,6 +22,11 @@
     - arg_meta: null
       kernel_name: torch::executor::log_softmax_out
 
+- op: _native_batch_norm_legit_no_training.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::_native_batch_norm_legit_no_training_out
+
 - op: _softmax.out
   kernels:
     - arg_meta: null