Skip to content

Commit 492b1b1

Browse files
swolchokYIWENX14
authored andcommitted
Support Half/BFloat16 in masked_fill (#7828)
Partial fix for #7748.
1 parent 64e681c commit 492b1b1

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

kernels/portable/cpu/op_masked_fill.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ Tensor& masked_fill_scalar_out(
4242
ET_KERNEL_CHECK(
4343
ctx, tensors_have_same_dim_order(in, mask, out), InvalidArgument, out);
4444

45-
ET_SWITCH_REAL_TYPES_AND(
46-
Bool, in_type, ctx, "masked_fill.Scalar_out", CTYPE, [&]() {
45+
ET_SWITCH_REALHBBF16_TYPES(
46+
in_type, ctx, "masked_fill.Scalar_out", CTYPE, [&]() {
4747
ET_SWITCH_REAL_TYPES_AND(
4848
Bool, val_type, ctx, "masked_fill.Scalar_out", CTYPE_VAL, [&]() {
4949
CTYPE_VAL value_v;

kernels/test/op_masked_fill_test.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,11 @@ TEST_F(OpMaskedFillTest, IntTensorFloatAlphaDies) {
114114
tf.ones(sizes), tf.ones(sizes), /*alpha=*/.7, out));
115115
}
116116

117-
TEST_F(OpMaskedFillTest, FloatTensors) {
118-
test_floating_point_masked_fill_scalar_out<ScalarType::Float>();
117+
TEST_F(OpMaskedFillTest, FloatingPointTensors) {
118+
#define TEST_ENTRY(ctype, dtype) \
119+
test_floating_point_masked_fill_scalar_out<ScalarType::dtype>();
120+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
121+
#undef TEST_ENTRY
119122
}
120123

121124
TEST_F(OpMaskedFillTest, DoubleTensors) {

0 commit comments

Comments
 (0)