Skip to content

Commit eee2bf1

Browse files
Add bf16 support to unary_ufunc_realh
Differential Revision: D71839099
Pull Request resolved: #9599
1 parent 7d35c68 commit eee2bf1

File tree

9 files changed

+68
-64
lines changed

9 files changed

+68
-64
lines changed

kernels/portable/cpu/op_ceil.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@ namespace native {
1717
using executorch::aten::Tensor;
1818

1919
Tensor& ceil_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
20-
return internal::unary_ufunc_realh(std::ceil, ctx, in, out);
20+
return internal::unary_ufunc_realhbf16(std::ceil, ctx, in, out);
2121
}
2222

2323
} // namespace native

kernels/portable/cpu/op_floor.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@ namespace native {
1717
using executorch::aten::Tensor;
1818

1919
Tensor& floor_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
20-
return internal::unary_ufunc_realh(std::floor, ctx, in, out);
20+
return internal::unary_ufunc_realhbf16(std::floor, ctx, in, out);
2121
}
2222

2323
} // namespace native

kernels/portable/cpu/op_trunc.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ namespace executor {
1515
namespace native {
1616

1717
Tensor& trunc_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
18-
return internal::unary_ufunc_realh(std::trunc, ctx, in, out);
18+
return internal::unary_ufunc_realhbf16(std::trunc, ctx, in, out);
1919
}
2020

2121
} // namespace native

kernels/portable/cpu/pattern/pattern.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,7 @@ namespace internal {
5959
* and dtype. The function fn specifies the math operation which is applied to
6060
* the input tensor element-wise.
6161
*/
62-
Tensor& unary_ufunc_realh(
62+
Tensor& unary_ufunc_realhbf16(
6363
double (*fn)(double),
6464
KernelRuntimeContext& ctx,
6565
const Tensor& in,

kernels/portable/cpu/pattern/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,7 @@ def define_common_targets():
5252
srcs = [
5353
"unary_ufunc_realhb_to_bool.cpp",
5454
"unary_ufunc_realhbbf16_to_floathbf16.cpp",
55-
"unary_ufunc_realh.cpp",
55+
"unary_ufunc_realhbf16.cpp",
5656
],
5757
exported_headers = [
5858
"pattern.h",

kernels/portable/cpu/pattern/unary_ufunc_realh.cpp renamed to kernels/portable/cpu/pattern/unary_ufunc_realhbf16.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ namespace executor {
1515
namespace native {
1616
namespace internal {
1717

18-
Tensor& unary_ufunc_realh(
18+
Tensor& unary_ufunc_realhbf16(
1919
double (*fn)(double),
2020
KernelRuntimeContext& ctx,
2121
const Tensor& in,
@@ -36,7 +36,7 @@ Tensor& unary_ufunc_realh(
3636
ET_KERNEL_CHECK(
3737
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
3838

39-
ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, __func__, CTYPE, [&] {
39+
ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, __func__, CTYPE, [&] {
4040
apply_unary_map_fn(
4141
[fn](const CTYPE val_in) { return static_cast<CTYPE>(fn(val_in)); },
4242
in.const_data_ptr<CTYPE>(),

kernels/test/op_ceil_test.cpp

Lines changed: 17 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -25,33 +25,28 @@ class OpCeilTest : public OperatorTest {
2525
Tensor& op_ceil_out(const Tensor& self, Tensor& out) {
2626
return torch::executor::aten::ceil_outf(context_, self, out);
2727
}
28-
};
2928

30-
TEST_F(OpCeilTest, SanityCheck) {
31-
TensorFactory<ScalarType::Float> tf;
29+
template <ScalarType DTYPE>
30+
void test_ceil_float_dtype() {
31+
TensorFactory<DTYPE> tf;
3232

33-
Tensor in = tf.make({1, 7}, {-3.0, -2.99, -1.01, 0.0, 1.01, 2.99, 3.0});
34-
Tensor out = tf.zeros({1, 7});
35-
Tensor expected = tf.make({1, 7}, {-3.0, -2.0, -1.0, 0.0, 2.0, 3.0, 3.0});
33+
Tensor in = tf.make({1, 7}, {-3.0, -2.99, -1.01, 0.0, 1.01, 2.99, 3.0});
34+
Tensor out = tf.zeros({1, 7});
35+
Tensor expected = tf.make({1, 7}, {-3.0, -2.0, -1.0, 0.0, 2.0, 3.0, 3.0});
3636

37-
Tensor ret = op_ceil_out(in, out);
37+
Tensor ret = op_ceil_out(in, out);
3838

39-
EXPECT_TENSOR_EQ(out, ret);
40-
EXPECT_TENSOR_EQ(out, expected);
41-
}
39+
EXPECT_TENSOR_EQ(out, ret);
40+
EXPECT_TENSOR_EQ(out, expected);
41+
}
42+
};
4243

43-
TEST_F(OpCeilTest, HalfSupport) {
44+
TEST_F(OpCeilTest, AllFloatDtypeSupport) {
45+
#define TEST_ENTRY(ctype, dtype) test_ceil_float_dtype<ScalarType::dtype>();
4446
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
45-
GTEST_SKIP() << "Test Half support only for ExecuTorch mode";
47+
ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
48+
} else {
49+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
4650
}
47-
TensorFactory<ScalarType::Half> tf;
48-
49-
Tensor in = tf.make({1, 7}, {-3.0, -2.99, -1.01, 0.0, 1.01, 2.99, 3.0});
50-
Tensor out = tf.zeros({1, 7});
51-
Tensor expected = tf.make({1, 7}, {-3.0, -2.0, -1.0, 0.0, 2.0, 3.0, 3.0});
52-
53-
Tensor ret = op_ceil_out(in, out);
54-
55-
EXPECT_TENSOR_EQ(out, ret);
56-
EXPECT_TENSOR_EQ(out, expected);
51+
#undef TEST_ENTRY
5752
}

kernels/test/op_floor_test.cpp

Lines changed: 17 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -25,33 +25,28 @@ class OpFloorTest : public OperatorTest {
2525
Tensor& op_floor_out(const Tensor& self, Tensor& out) {
2626
return torch::executor::aten::floor_outf(context_, self, out);
2727
}
28-
};
2928

30-
TEST_F(OpFloorTest, SanityCheck) {
31-
TensorFactory<ScalarType::Float> tf;
29+
template <ScalarType DTYPE>
30+
void test_floor_float_dtype() {
31+
TensorFactory<DTYPE> tf;
3232

33-
Tensor in = tf.make({1, 7}, {-3.0, -2.99, -1.01, 0.0, 1.01, 2.99, 3.0});
34-
Tensor out = tf.zeros({1, 7});
35-
Tensor expected = tf.make({1, 7}, {-3.0, -3.0, -2.0, 0.0, 1.0, 2.0, 3.0});
33+
Tensor in = tf.make({1, 7}, {-3.0, -2.99, -1.01, 0.0, 1.01, 2.99, 3.0});
34+
Tensor out = tf.zeros({1, 7});
35+
Tensor expected = tf.make({1, 7}, {-3.0, -3.0, -2.0, 0.0, 1.0, 2.0, 3.0});
3636

37-
Tensor ret = op_floor_out(in, out);
37+
Tensor ret = op_floor_out(in, out);
3838

39-
EXPECT_TENSOR_EQ(out, ret);
40-
EXPECT_TENSOR_EQ(out, expected);
41-
}
39+
EXPECT_TENSOR_EQ(out, ret);
40+
EXPECT_TENSOR_EQ(out, expected);
41+
}
42+
};
4243

43-
TEST_F(OpFloorTest, HalfSupport) {
44+
TEST_F(OpFloorTest, AllFloatDtypeSupport) {
45+
#define TEST_ENTRY(ctype, dtype) test_floor_float_dtype<ScalarType::dtype>();
4446
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
45-
GTEST_SKIP() << "Test Half support only for ExecuTorch mode";
47+
ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
48+
} else {
49+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
4650
}
47-
TensorFactory<ScalarType::Half> tf;
48-
49-
Tensor in = tf.make({1, 7}, {-3.0, -2.99, -1.01, 0.0, 1.01, 2.99, 3.0});
50-
Tensor out = tf.zeros({1, 7});
51-
Tensor expected = tf.make({1, 7}, {-3.0, -3.0, -2.0, 0.0, 1.0, 2.0, 3.0});
52-
53-
Tensor ret = op_floor_out(in, out);
54-
55-
EXPECT_TENSOR_EQ(out, ret);
56-
EXPECT_TENSOR_EQ(out, expected);
51+
#undef TEST_ENTRY
5752
}

kernels/test/op_trunc_test.cpp

Lines changed: 27 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -22,19 +22,33 @@ using executorch::aten::Tensor;
2222
using torch::executor::testing::SupportedFeatures;
2323
using torch::executor::testing::TensorFactory;
2424

25-
Tensor& op_trunc_out(const Tensor& a, Tensor& out) {
26-
executorch::runtime::KernelRuntimeContext context{};
27-
return torch::executor::aten::trunc_outf(context, a, out);
28-
}
25+
class OpTruncTest : public OperatorTest {
26+
protected:
27+
Tensor& op_trunc_out(const Tensor& self, Tensor& out) {
28+
return torch::executor::aten::trunc_outf(context_, self, out);
29+
}
30+
31+
template <ScalarType DTYPE>
32+
void test_trunc_float_dtype() {
33+
TensorFactory<DTYPE> tf;
34+
35+
Tensor in = tf.make({1, 6}, {60.5, 16.25, -95.0, -36.125, 19.0, -47.75});
36+
Tensor out = tf.zeros({1, 6});
37+
Tensor expected = tf.make({1, 6}, {60.0, 16.0, -95.0, -36.0, 19.0, -47.0});
38+
39+
Tensor ret = op_trunc_out(in, out);
2940

30-
TEST(OpTruncOutTest, SmokeTest) {
31-
TensorFactory<ScalarType::Double> tfDouble;
41+
EXPECT_TENSOR_EQ(out, ret);
42+
EXPECT_TENSOR_EQ(out, expected);
43+
}
44+
};
3245

33-
Tensor self =
34-
tfDouble.make({1, 6}, {60.5, 16.25, -95.0, -36.125, 19.0, -47.75});
35-
Tensor out = tfDouble.zeros({1, 6});
36-
Tensor out_expected =
37-
tfDouble.make({1, 6}, {60.0, 16.0, -95.0, -36.0, 19.0, -47.0});
38-
op_trunc_out(self, out);
39-
EXPECT_TENSOR_CLOSE(out, out_expected);
46+
TEST_F(OpTruncTest, AllFloatDtypeSupport) {
47+
#define TEST_ENTRY(ctype, dtype) test_trunc_float_dtype<ScalarType::dtype>();
48+
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
49+
ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
50+
} else {
51+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
52+
}
53+
#undef TEST_ENTRY
4054
}

0 commit comments

Comments
 (0)