Skip to content

Commit f8b0e91

Browse files
[ET][Portable] Fix op native_batch_norm
Resize `mean_out` & `var_out` Differential Revision: [D50081049](https://our.internmc.facebook.com/intern/diff/D50081049/) ghstack-source-id: 203404908 Pull Request resolved: #725
1 parent 9bdf02f commit f8b0e91

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

kernels/portable/cpu/op_native_batch_norm.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ std::tuple<Tensor&, Tensor&, Tensor&> _native_batch_norm_legit_no_training_out(
4141
InvalidArgument,
4242
ret_val);
4343

44+
ET_KERNEL_CHECK(
45+
ctx, resize_tensor(mean_out, {0}) == Error::Ok, InvalidArgument, ret_val);
46+
47+
ET_KERNEL_CHECK(
48+
ctx, resize_tensor(var_out, {0}) == Error::Ok, InvalidArgument, ret_val);
49+
4450
ET_KERNEL_CHECK(
4551
ctx,
4652
check_batch_norm_args(
@@ -56,6 +62,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> _native_batch_norm_legit_no_training_out(
5662
var_out),
5763
InvalidArgument,
5864
ret_val);
65+
5966
// For now, only support the default dim order
6067
ET_KERNEL_CHECK(
6168
ctx,

kernels/portable/cpu/util/normalization_ops_util.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ bool check_batch_norm_args(
4545
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(running_mean, 1));
4646
ET_LOG_AND_RETURN_IF_FALSE(
4747
tensors_have_same_size_at_dims(running_mean, 0, in, C_dim));
48+
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(running_var, 1));
49+
ET_LOG_AND_RETURN_IF_FALSE(
50+
tensors_have_same_size_at_dims(running_var, 0, in, C_dim));
4851
if (weight.has_value()) {
4952
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(weight.value(), 1));
5053
ET_LOG_AND_RETURN_IF_FALSE(

0 commit comments

Comments
 (0)