Skip to content

Commit 82ca9cf

Browse files
swolchokYIWENX14
authored andcommitted
Support Half/BFloat16 in fill (#7809)
Partial fix for #7748.
1 parent abc08a3 commit 82ca9cf

File tree

3 files changed

+16
-13
lines changed

3 files changed

+16
-13
lines changed

kernels/portable/cpu/op_fill.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ Tensor& fill_scalar_out(
4242
out,
4343
"Failed to resize output tensor.");
4444

45-
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "fill.Scalar_out", CTYPE_A, [&] {
45+
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "fill.Scalar_out", CTYPE_A, [&] {
4646
CTYPE_A b_casted;
4747
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "fill.Scalar_out", CTYPE_B, [&] {
4848
CTYPE_B b_val;
@@ -87,14 +87,14 @@ Tensor& fill_tensor_out(
8787
out,
8888
"Failed to resize output tensor.");
8989

90-
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "fill.Tensor_out", CTYPE_A, [&] {
90+
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "fill.Tensor_out", CTYPE_A, [&] {
9191
CTYPE_A b_casted;
92-
ET_SWITCH_REAL_TYPES_AND(
93-
Bool, b_type, ctx, "fill.Tensor_out", CTYPE_B, [&] {
94-
CTYPE_B b_val;
95-
extract_scalar_tensor(b, &b_val);
96-
b_casted = static_cast<CTYPE_A>(b_val);
97-
});
92+
ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "fill.Tensor_out", CTYPE_B, [&] {
93+
CTYPE_B b_val;
94+
ET_DCHECK_MSG(
95+
extract_scalar_tensor(b, &b_val), "extract_scalar_tensor failed!");
96+
b_casted = static_cast<CTYPE_A>(b_val);
97+
});
9898

9999
apply_unary_map_fn(
100100
[b_casted](const CTYPE_A val_a) { return b_casted; },

kernels/test/op_fill_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,15 @@ class OpFillTest : public OperatorTest {
9292
TEST_FILL_OUT(test_fill_scalar_out, DTYPE); \
9393
}
9494

95-
ET_FORALL_REAL_TYPES_AND(Bool, GENERATE_SCALAR_INPUT_SUPPORT_TEST)
95+
ET_FORALL_REALHBBF16_TYPES(GENERATE_SCALAR_INPUT_SUPPORT_TEST)
9696

9797
// Create input support tests for tensor variant.
9898
#define GENERATE_TENSOR_INPUT_SUPPORT_TEST(_, DTYPE) \
9999
TEST_F(OpFillTest, DTYPE##TensorInputSupport) { \
100100
TEST_FILL_OUT(test_fill_tensor_out, DTYPE); \
101101
}
102102

103-
ET_FORALL_REAL_TYPES_AND(Bool, GENERATE_TENSOR_INPUT_SUPPORT_TEST)
103+
ET_FORALL_REALHBBF16_TYPES(GENERATE_TENSOR_INPUT_SUPPORT_TEST)
104104

105105
TEST_F(OpFillTest, MismatchedOtherPropertiesDies) {
106106
TensorFactory<ScalarType::Int> tf;

runtime/core/exec_aten/util/tensor_util.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,8 +1062,11 @@ bool extract_scalar_tensor(executorch::aten::Tensor tensor, INT_T* out_val) {
10621062
*/
10631063
template <
10641064
typename FLOAT_T,
1065-
typename std::enable_if<std::is_floating_point<FLOAT_T>::value, bool>::
1066-
type = true>
1065+
typename std::enable_if<
1066+
std::is_floating_point_v<FLOAT_T> ||
1067+
std::is_same_v<FLOAT_T, exec_aten::BFloat16> ||
1068+
std::is_same_v<FLOAT_T, exec_aten::Half>,
1069+
bool>::type = true>
10671070
bool extract_scalar_tensor(executorch::aten::Tensor tensor, FLOAT_T* out_val) {
10681071
if (tensor.numel() != 1) {
10691072
return false;
@@ -1083,7 +1086,7 @@ bool extract_scalar_tensor(executorch::aten::Tensor tensor, FLOAT_T* out_val) {
10831086
}
10841087

10851088
switch (tensor.scalar_type()) {
1086-
ET_FORALL_REAL_TYPES(CASE_REAL_DTYPE);
1089+
ET_FORALL_REALHBF16_TYPES(CASE_REAL_DTYPE);
10871090
default:
10881091
return false;
10891092
}

0 commit comments

Comments
 (0)