Skip to content

Commit 64e681c

Browse files
swolchokYIWENX14
authored andcommitted
Support Half/BFloat16 in logical_not (#7827)
Partial fix for #7748.
1 parent 96f5807 commit 64e681c

File tree

3 files changed

+10
-6
lines changed

3 files changed

+10
-6
lines changed

kernels/portable/cpu/op_logical_not.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ logical_not_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
3333

3434
ET_KERNEL_CHECK(ctx, tensors_have_same_shape(in, out), InvalidArgument, out);
3535

36-
ET_SWITCH_REAL_TYPES_AND(
37-
Bool, in.scalar_type(), ctx, "logical_not.out", CTYPE_IN, [&] {
38-
ET_SWITCH_REAL_TYPES_AND(
39-
Bool, out.scalar_type(), ctx, "logical_not.out", CTYPE_OUT, [&] {
36+
ET_SWITCH_REALHBBF16_TYPES(
37+
in.scalar_type(), ctx, "logical_not.out", CTYPE_IN, [&] {
38+
ET_SWITCH_REALHBBF16_TYPES(
39+
out.scalar_type(), ctx, "logical_not.out", CTYPE_OUT, [&] {
4040
apply_unary_map_fn(
4141
[](const CTYPE_IN val_in) {
4242
return static_cast<CTYPE_OUT>(!static_cast<bool>(val_in));

kernels/test/op_logical_not_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,9 @@ TEST_F(OpLogicalNotOutTest, AllTypePasses) {
122122
test_logical_not_out<ScalarType::INPUT_DTYPE, ScalarType::OUTPUT_DTYPE>();
123123

124124
#define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \
125-
ET_FORALL_REAL_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
125+
ET_FORALL_REALHBBF16_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
126126

127-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
127+
ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);
128128
#undef TEST_ENTRY
129129
#undef TEST_KERNEL
130130
}

runtime/core/exec_aten/util/scalar_type_util.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,10 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType)
252252
_(ANOTHER_INPUT1, ANOTHER_INPUT2, ::executorch::aten::Half, Half) \
253253
_(ANOTHER_INPUT1, ANOTHER_INPUT2, ::executorch::aten::BFloat16, BFloat16)
254254

255+
#define ET_FORALL_REALHBBF16_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \
256+
ET_FORALL_REALHBF16_TYPES_WITH2(ANOTHER_INPUT2, ANOTHER_INPUT2, _) \
257+
_(ANOTHER_INPUT1, ANOTHER_INPUT2, bool, Bool)
258+
255259
// For macros that take `SCALARTYPEn` parameters, those parameters should be
256260
// an unquoted/unqualified enumerator name like `Int` or `Float`.
257261
#define ET_FORALL_REAL_TYPES_AND(SCALARTYPE, _) \

0 commit comments

Comments
 (0)