Skip to content

Commit 5db40f2

Browse files
authored
Support Half/BFloat16 in nonzero (#7850)
Partial fix for #7748.
1 parent e000b22 commit 5db40f2

File tree

2 files changed

+6
-9
lines changed

2 files changed

+6
-9
lines changed

kernels/portable/cpu/op_nonzero.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,9 @@ Tensor& nonzero_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
8888

8989
ET_KERNEL_CHECK(ctx, check_nonzero_args(in, out), InvalidArgument, out);
9090

91-
ET_SWITCH_REAL_TYPES_AND(
92-
Bool, in.scalar_type(), ctx, "nonzero.out", CTYPE, [&] {
93-
nonzero<CTYPE>(ctx, in, out);
94-
});
91+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "nonzero.out", CTYPE, [&] {
92+
nonzero<CTYPE>(ctx, in, out);
93+
});
9594

9695
return out;
9796
}

kernels/test/op_nonzero_test.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,8 @@ class OpNonzeroTest : public OperatorTest {
2828
void test_dtype() {
2929
TensorFactory<DTYPE> tf_input;
3030
TensorFactory<ScalarType::Long> tf_long;
31-
// clang-format off
32-
Tensor a = tf_input.make(/*sizes=*/{2, 2}, /*data=*/{2, 0,
33-
2, 4});
34-
// clang-format on
31+
Tensor a = tf_input.make(
32+
/*sizes=*/{2, 2}, /*data=*/{CTYPE(2), CTYPE(0), CTYPE(2), CTYPE(4)});
3533
Tensor out = tf_long.zeros({3, 2});
3634

3735
op_nonzero_out(a, out);
@@ -45,7 +43,7 @@ class OpNonzeroTest : public OperatorTest {
4543

4644
TEST_F(OpNonzeroTest, AllDtypesSupported) {
4745
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
48-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
46+
ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);
4947
#undef TEST_ENTRY
5048
}
5149

0 commit comments

Comments
 (0)