Skip to content

Commit 83fb692

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Pattern: unary realb to float (exp & log)
Reviewed By: digantdesai Differential Revision: D47118015 fbshipit-source-id: 03a889582cc669c5c6e2817db97e78e8a1a51775
1 parent 2b79873 commit 83fb692

File tree

5 files changed

+23
-87
lines changed

5 files changed

+23
-87
lines changed

kernels/optimized/cpu/op_exp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ Tensor& opt_exp_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
5656
auto error = resize_tensor(out, in.sizes());
5757
ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
5858

59-
ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "exp", CTYPE_IN, [&] {
59+
ET_SWITCH_REAL_TYPES_AND(Bool, in.scalar_type(), ctx, "exp", CTYPE_IN, [&] {
6060
ET_SWITCH_FLOAT_TYPES(out.scalar_type(), ctx, "exp", CTYPE_OUT, [&] {
6161
exp_data<CTYPE_IN, CTYPE_OUT>(
6262
in.const_data_ptr<CTYPE_IN>(),

kernels/portable/cpu/op_exp.cpp

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,15 @@
1-
// Copyright (c) Meta Platforms, Inc. and affiliates.
2-
#include <cmath>
1+
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
32

43
#include <executorch/kernels/kernel_includes.h>
5-
#include <executorch/kernels/portable/cpu/util/functional_util.h>
4+
#include <executorch/kernels/portable/cpu/pattern/pattern.h>
5+
#include <cmath>
66

77
namespace torch {
88
namespace executor {
99
namespace native {
1010

11-
using exec_aten::Tensor;
12-
1311
Tensor& exp_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
14-
(void)ctx;
15-
16-
// Resize for dynamic shape
17-
auto error = resize_tensor(out, in.sizes());
18-
ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
19-
ET_CHECK_SAME_SHAPE2(in, out);
20-
21-
ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "exp", CTYPE_IN, [&] {
22-
ET_SWITCH_FLOAT_TYPES(out.scalar_type(), ctx, "exp", CTYPE_OUT, [&] {
23-
apply_unary_map_fn(
24-
[](const CTYPE_IN val_in) {
25-
CTYPE_OUT xi = static_cast<CTYPE_OUT>(val_in);
26-
return std::exp(xi);
27-
},
28-
in.const_data_ptr<CTYPE_IN>(),
29-
out.mutable_data_ptr<CTYPE_OUT>(),
30-
in.numel());
31-
});
32-
});
33-
34-
return out;
12+
return internal::unary_ufunc_realb_to_float(std::exp, ctx, in, out);
3513
}
3614

3715
} // namespace native

kernels/portable/cpu/op_log.cpp

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,15 @@
1-
// Copyright (c) Meta Platforms, Inc. and affiliates.
2-
3-
#include <cmath>
1+
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
42

53
#include <executorch/kernels/kernel_includes.h>
6-
#include <executorch/kernels/portable/cpu/util/functional_util.h>
4+
#include <executorch/kernels/portable/cpu/pattern/pattern.h>
5+
#include <cmath>
76

87
namespace torch {
98
namespace executor {
109
namespace native {
1110

12-
using exec_aten::Tensor;
13-
1411
Tensor& log_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
15-
(void)ctx;
16-
17-
// Resize for dynamic shape
18-
auto error = resize_tensor(out, in.sizes());
19-
ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
20-
ET_CHECK_SAME_SHAPE2(in, out);
21-
22-
ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "log", CTYPE_IN, [&] {
23-
ET_SWITCH_FLOAT_TYPES(out.scalar_type(), ctx, "log", CTYPE_OUT, [&] {
24-
apply_unary_map_fn(
25-
[](const CTYPE_IN val_in) {
26-
CTYPE_OUT xi = static_cast<CTYPE_OUT>(val_in);
27-
ET_CHECK_MSG(xi > 0.0, "input must be greater than 0.");
28-
return static_cast<CTYPE_OUT>(log(xi));
29-
},
30-
in.const_data_ptr<CTYPE_IN>(),
31-
out.mutable_data_ptr<CTYPE_OUT>(),
32-
in.numel());
33-
});
34-
});
35-
36-
return out;
12+
return internal::unary_ufunc_realb_to_float(std::log, ctx, in, out);
3713
}
3814

3915
} // namespace native

kernels/test/op_exp_test.cpp

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -72,24 +72,18 @@ TEST(OpExpOutKernelTest, AllFloatInputDoubleOutputSupport) {
7272
#undef TEST_ENTRY
7373
}
7474

75-
TEST(OpExpOutKernelTest, UnhandledInputDtypeDies) {
76-
if (SupportedFeatures::get()->is_aten) {
77-
GTEST_SKIP() << "ATen kernel can handle bool dtype";
78-
}
79-
80-
// _exp_out() doesn't handle Bool as input.
75+
TEST(OpExpOutKernelTest, HandleBoolInput) {
76+
// _exp_out() handles Bool as input.
8177
TensorFactory<ScalarType::Bool> tf_bool;
8278
TensorFactory<ScalarType::Float> tf_float;
8379

84-
const std::vector<int32_t> sizes = {2, 2};
85-
Tensor a = tf_bool.make(sizes, /*data=*/{false, true, false, true});
80+
const std::vector<int32_t> sizes = {1, 2};
8681

87-
// Destination for the log
82+
Tensor a = tf_bool.make(sizes, /*data=*/{true, false});
8883
Tensor out = tf_float.zeros(sizes);
84+
Tensor res = tf_float.make(sizes, /*data=*/{2.718282, 1});
8985

90-
// Boolean tensor should cause an assertion and kill the
91-
// test process.
92-
ET_EXPECT_KERNEL_FAILURE(_exp_out(a, out));
86+
EXPECT_TENSOR_CLOSE(_exp_out(a, out), res);
9387
}
9488

9589
// Mismatched shape tests.

kernels/test/op_log_test.cpp

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,12 @@ void test__log_out() {
2929

3030
const std::vector<int32_t> sizes = {2, 2};
3131

32-
// Invalid input zero should die
3332
Tensor out = tf_out.zeros(sizes);
34-
if (SupportedFeatures::get()->is_aten) {
35-
// However, ATen can handle when input is zero
36-
} else {
37-
ET_EXPECT_KERNEL_FAILURE(_log_out(tf.zeros(sizes), out));
38-
}
3933

4034
// Valid input should give the expected output
41-
_log_out(tf.make(sizes, /*data=*/{1, 2, 4, 8}), out);
35+
_log_out(tf.make(sizes, /*data=*/{0, 1, 2, 4}), out);
4236
EXPECT_TENSOR_CLOSE(
43-
out, tf_out.make(sizes, /*data=*/{0, 0.693147, 1.386294, 2.079441}));
37+
out, tf_out.make(sizes, /*data=*/{-INFINITY, 0, 0.693147, 1.386294}));
4438
}
4539

4640
TEST(OpLogOutKernelTest, AllRealInputFloatOutputSupport) {
@@ -57,24 +51,18 @@ TEST(OpLogOutKernelTest, AllRealInputDoubleOutputSupport) {
5751
#undef TEST_ENTRY
5852
}
5953

60-
TEST(OpLogOutKernelTest, UnhandledInputDtypeDies) {
61-
if (SupportedFeatures::get()->is_aten) {
62-
GTEST_SKIP() << "ATen kernel can handle bool dtype";
63-
}
64-
65-
// _log_out() doesn't handle Bool as input.
54+
TEST(OpLogOutKernelTest, HandleBoolInput) {
55+
// _log_out() handles Bool as input.
6656
TensorFactory<ScalarType::Bool> tf_bool;
6757
TensorFactory<ScalarType::Float> tf_float;
6858

69-
const std::vector<int32_t> sizes = {2, 2};
70-
Tensor a = tf_bool.make(sizes, /*data=*/{false, true, false, true});
59+
const std::vector<int32_t> sizes = {1, 2};
7160

72-
// Destination for the log
61+
Tensor a = tf_bool.make(sizes, /*data=*/{true, false});
7362
Tensor out = tf_float.zeros(sizes);
63+
Tensor res = tf_float.make(sizes, /*data=*/{0, -INFINITY});
7464

75-
// Boolean tensor should cause an assertion and kill the
76-
// test process.
77-
ET_EXPECT_KERNEL_FAILURE(_log_out(a, out));
65+
EXPECT_TENSOR_EQ(_log_out(a, out), res);
7866
}
7967

8068
// Mismatched shape tests.

0 commit comments

Comments
 (0)