Skip to content

Commit 96f5807

Browse files
swolchok authored and YIWENX14 committed
Support Half/BFloat16 in log_softmax (#7826)
Partial fix for #7748.
1 parent 1d88092 commit 96f5807

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

kernels/portable/cpu/op_log_softmax.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ Tensor& log_softmax_out(
4242
// Adjust for negative dim
4343
dim = dim < 0 ? dim + nonzero_dim(in) : dim;
4444

45-
ET_SWITCH_FLOAT_TYPES(
45+
ET_SWITCH_FLOATHBF16_TYPES(
4646
in.scalar_type(), ctx, "_log_softmax.out", CTYPE, [&]() {
4747
const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
4848
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();

kernels/test/op_log_softmax_test.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,15 @@ class OpLogSoftmaxOutTest : public OperatorTest {
6262
});
6363
// clang-format on
6464

65-
EXPECT_TENSOR_CLOSE(out, expected);
65+
if constexpr (DTYPE == ScalarType::BFloat16) {
66+
EXPECT_TENSOR_CLOSE_WITH_TOL(
67+
out,
68+
expected,
69+
1e-2,
70+
executorch::runtime::testing::internal::kDefaultAtol);
71+
} else {
72+
EXPECT_TENSOR_CLOSE(out, expected);
73+
}
6674
}
6775
};
6876

@@ -88,11 +96,9 @@ TEST_F(OpLogSoftmaxOutTest, AllDtypesSupported) {
8896
GTEST_SKIP() << "This kernel does not support dtype double";
8997
}
9098

91-
test_dtype<float, ScalarType::Float>();
92-
test_dtype<double, ScalarType::Double>();
93-
// TODO: Also add tests for half, complex, quantized, and other types. Easiest
94-
// way to do that would be to make TensorFactory support zeros() and ones()
95-
// for those types.
99+
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
100+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY)
101+
#undef TEST_ENTRY
96102
}
97103

98104
TEST_F(OpLogSoftmaxOutTest, MismatchedDimensionsDies) {

0 commit comments

Comments (0)