Skip to content

Commit 4bc2029

Browse files
authored
Support Half/BFloat16 in mean (#7837)
Partial fix for #7748.
1 parent 5d1595a commit 4bc2029

File tree

2 files changed

+32
-19
lines changed

2 files changed

+32
-19
lines changed

kernels/portable/cpu/op_mean.cpp

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -44,23 +44,24 @@ Tensor& mean_dim_out(
4444
InvalidArgument,
4545
out);
4646

47-
ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
48-
ET_SWITCH_FLOATH_TYPES(out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
49-
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
50-
const size_t num = get_reduced_dim_product(in, dim_list);
51-
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
52-
CTYPE_OUT sum = 0;
53-
if (in.numel() > 0) {
54-
sum = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
55-
[](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
56-
[](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
57-
in,
58-
dim_list,
59-
out_ix);
60-
}
61-
out_data[out_ix] = sum / static_cast<float>(num);
62-
}
63-
});
47+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
48+
ET_SWITCH_FLOATHBF16_TYPES(
49+
out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
50+
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
51+
const size_t num = get_reduced_dim_product(in, dim_list);
52+
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
53+
CTYPE_OUT sum = 0;
54+
if (in.numel() > 0) {
55+
sum = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
56+
[](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
57+
[](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
58+
in,
59+
dim_list,
60+
out_ix);
61+
}
62+
out_data[out_ix] = sum / static_cast<float>(num);
63+
}
64+
});
6465
});
6566

6667
return out;

kernels/test/op_mean_test.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,18 @@ class OpMeanOutTest : public OperatorTest {
238238
}
239239
};
240240

241+
template <>
242+
void OpMeanOutTest::
243+
test_mean_dim_out_dtype<ScalarType::Bool, ScalarType::Half>() {
244+
test_mean_dim_out_bool<ScalarType::Half>();
245+
}
246+
247+
template <>
248+
void OpMeanOutTest::
249+
test_mean_dim_out_dtype<ScalarType::Bool, ScalarType::BFloat16>() {
250+
test_mean_dim_out_bool<ScalarType::BFloat16>();
251+
}
252+
241253
template <>
242254
void OpMeanOutTest::
243255
test_mean_dim_out_dtype<ScalarType::Bool, ScalarType::Float>() {
@@ -331,9 +343,9 @@ TEST_F(OpMeanOutTest, AllRealInputFloatOutputPasses) {
331343
test_mean_dim_out_dtype<ScalarType::INPUT_DTYPE, ScalarType::OUTPUT_DTYPE>();
332344

333345
#define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \
334-
ET_FORALL_FLOAT_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
346+
ET_FORALL_FLOATHBF16_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
335347

336-
ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
348+
ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);
337349
#undef TEST_ENTRY
338350
#undef TEST_KERNEL
339351
}

0 commit comments

Comments
 (0)