
Commit 3644b60

manuelcandales authored and facebook-github-bot committed
Dtype compliance: native_batch_norm
Reviewed By: SS-JIA
Differential Revision: D48371009
fbshipit-source-id: 9b1b410438b817dd8e3ec97c9eb8f2c60c14a2ee
1 parent b35b665 · commit 3644b60

3 files changed: +18 -3 lines changed

kernels/portable/cpu/op_native_batch_norm.cpp

Lines changed: 10 additions & 1 deletion
@@ -44,7 +44,16 @@ std::tuple<Tensor&, Tensor&, Tensor&> _native_batch_norm_legit_no_training_out(
   ET_KERNEL_CHECK(
       ctx,
       check_batch_norm_args(
-          in, weight, bias, running_mean, running_var, momentum, eps, out),
+          in,
+          weight,
+          bias,
+          running_mean,
+          running_var,
+          momentum,
+          eps,
+          out,
+          mean_out,
+          var_out),
       InvalidArgument,
       ret_val);
   // For now, only support the default dim order
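For context, this hunk is the guard at the top of the kernel: ET_KERNEL_CHECK runs check_batch_norm_args before any computation, and on failure it reports InvalidArgument and the op returns ret_val with the outputs untouched. Passing mean_out and var_out into the check is what extends the dtype validation to the two extra outputs. The following is a minimal, hypothetical sketch of that guard-then-compute shape only; KernelCtx, fail(), and args_ok are placeholders, not ExecuTorch APIs.

#include <tuple>

// Placeholder types standing in for the real ExecuTorch runtime objects.
struct Tensor {};
struct KernelCtx {
  bool failed = false;
  void fail() { failed = true; }  // roughly what a failed kernel check records
};

// Stand-in for check_batch_norm_args; the real helper now also verifies that
// mean_out and var_out share the input's dtype.
bool args_ok(const Tensor&, const Tensor&, const Tensor&, const Tensor&) {
  return true;
}

std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_sketch(
    KernelCtx& ctx,
    const Tensor& in,
    Tensor& out,
    Tensor& mean_out,
    Tensor& var_out) {
  std::tuple<Tensor&, Tensor&, Tensor&> ret_val(out, mean_out, var_out);
  // Guard-then-compute: validate everything (including the new outputs)
  // before touching any data, mirroring ET_KERNEL_CHECK above.
  if (!args_ok(in, out, mean_out, var_out)) {
    ctx.fail();
    return ret_val;  // bail out early on invalid arguments
  }
  // ... the actual batch-norm computation would go here ...
  return ret_val;
}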

kernels/portable/cpu/util/normalization_ops_util.cpp

Lines changed: 5 additions & 1 deletion
@@ -23,11 +23,15 @@ bool check_batch_norm_args(
     const Tensor& running_var,
     double momentum,
     double eps,
-    Tensor& out) {
+    Tensor& out,
+    Tensor& mean_out,
+    Tensor& var_out) {
   // All tensors must be the same dtype
   ET_LOG_AND_RETURN_IF_FALSE(
       tensors_have_same_dtype(in, running_mean, running_var));
   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, mean_out));
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, var_out));
   if (weight.has_value()) {
     ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, weight.value()));
   }
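The two added lines are the substance of the dtype-compliance fix: in, running_mean, running_var, out, and now mean_out and var_out must all share one dtype, otherwise the check logs and returns false. Below is a small self-contained sketch of that rule under placeholder types; ScalarType and scalar_type() here only stand in for the real tensor API, and all_same_dtype is not an ExecuTorch function.

#include <initializer_list>

// Placeholder tensor with just enough surface to express the dtype rule.
enum class ScalarType { Float, Double, Half };
struct Tensor {
  ScalarType dtype;
  ScalarType scalar_type() const { return dtype; }
};

// Every listed tensor must match the reference tensor's dtype, which is what
// the tensors_have_same_dtype checks above enforce pairwise.
bool all_same_dtype(const Tensor& ref, std::initializer_list<const Tensor*> rest) {
  for (const Tensor* t : rest) {
    if (t->scalar_type() != ref.scalar_type()) {
      return false;  // mismatch: the kernel would reject these arguments
    }
  }
  return true;
}

// Usage mirroring the updated check: the running stats and all three outputs
// (out, mean_out, var_out) must carry the input's dtype.
bool batch_norm_dtypes_ok(
    const Tensor& in,
    const Tensor& running_mean,
    const Tensor& running_var,
    const Tensor& out,
    const Tensor& mean_out,
    const Tensor& var_out) {
  return all_same_dtype(
      in, {&running_mean, &running_var, &out, &mean_out, &var_out});
}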

kernels/portable/cpu/util/normalization_ops_util.h

Lines changed: 3 additions & 1 deletion
@@ -21,7 +21,9 @@ bool check_batch_norm_args(
     const Tensor& running_var,
     double momentum,
     double eps,
-    Tensor& out);
+    Tensor& out,
+    Tensor& mean_out,
+    Tensor& var_out);

 bool check_layer_norm_args(
     const Tensor& input,
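A side effect of widening this declaration is that any call site still passing only the original eight arguments no longer compiles, so every batch-norm kernel variant is forced to hand mean_out and var_out to the check. A toy illustration of that signature-widening effect, using a stub type rather than the real Tensor:

// Stub standing in for the real Tensor type.
struct Tensor {};

// Old shape of the declaration, for comparison only:
//   bool check_args(Tensor& out);
// New shape: the extra outputs are required parameters.
bool check_args(Tensor& out, Tensor& mean_out, Tensor& var_out) {
  (void)out; (void)mean_out; (void)var_out;
  return true;  // the real helper verifies dtypes here
}

int main() {
  Tensor out, mean_out, var_out;
  // check_args(out);  // would fail to compile against the new declaration
  return check_args(out, mean_out, var_out) ? 0 : 1;
}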
