Skip to content

Commit 403c1ea

Browse files
committed
Update
[ghstack-poisoned]
1 parent 9932759 commit 403c1ea

File tree

1 file changed

+17
-16
lines changed

1 file changed

+17
-16
lines changed

kernels/portable/cpu/op_mean.cpp

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -45,22 +45,23 @@ Tensor& mean_dim_out(
4545
out);
4646

4747
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
48-
ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
49-
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
50-
const size_t num = get_reduced_dim_product(in, dim_list);
51-
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
52-
CTYPE_OUT sum = 0;
53-
if (in.numel() > 0) {
54-
sum = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
55-
[](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
56-
[](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
57-
in,
58-
dim_list,
59-
out_ix);
60-
}
61-
out_data[out_ix] = sum / static_cast<float>(num);
62-
}
63-
});
48+
ET_SWITCH_FLOATHBF16_TYPES(
49+
out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
50+
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
51+
const size_t num = get_reduced_dim_product(in, dim_list);
52+
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
53+
CTYPE_OUT sum = 0;
54+
if (in.numel() > 0) {
55+
sum = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
56+
[](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
57+
[](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
58+
in,
59+
dim_list,
60+
out_ix);
61+
}
62+
out_data[out_ix] = sum / static_cast<float>(num);
63+
}
64+
});
6465
});
6566

6667
return out;

0 commit comments

Comments
 (0)