Skip to content

Commit 9932759

Browse files
committed
Update
[ghstack-poisoned]
1 parent b04912f commit 9932759

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

kernels/portable/cpu/op_mean.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ Tensor& mean_dim_out(
4444
InvalidArgument,
4545
out);
4646

47-
ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
48-
ET_SWITCH_FLOATH_TYPES(out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
47+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
48+
ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
4949
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
5050
const size_t num = get_reduced_dim_product(in, dim_list);
5151
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {

kernels/test/op_mean_test.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,18 @@ class OpMeanOutTest : public OperatorTest {
238238
}
239239
};
240240

241+
template <>
242+
void OpMeanOutTest::
243+
test_mean_dim_out_dtype<ScalarType::Bool, ScalarType::Half>() {
244+
test_mean_dim_out_bool<ScalarType::Half>();
245+
}
246+
247+
template <>
248+
void OpMeanOutTest::
249+
test_mean_dim_out_dtype<ScalarType::Bool, ScalarType::BFloat16>() {
250+
test_mean_dim_out_bool<ScalarType::BFloat16>();
251+
}
252+
241253
template <>
242254
void OpMeanOutTest::
243255
test_mean_dim_out_dtype<ScalarType::Bool, ScalarType::Float>() {
@@ -331,9 +343,9 @@ TEST_F(OpMeanOutTest, AllRealInputFloatOutputPasses) {
331343
test_mean_dim_out_dtype<ScalarType::INPUT_DTYPE, ScalarType::OUTPUT_DTYPE>();
332344

333345
#define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \
334-
ET_FORALL_FLOAT_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
346+
ET_FORALL_FLOATHBF16_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
335347

336-
ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
348+
ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);
337349
#undef TEST_ENTRY
338350
#undef TEST_KERNEL
339351
}

0 commit comments

Comments
 (0)