Skip to content

Commit 5b4d9bb

Browse files
[Executorch] optimized sigmoid
Pull Request resolved: #6522. Uses the Sleef-based vectorized exp approximation instead of scalar std::exp. ghstack-source-id: 254026289 @exported-using-ghexport Differential Revision: [D64156864](https://our.internmc.facebook.com/intern/diff/D64156864/) Co-authored-by: Kimish Patel <[email protected]>
1 parent c242a59 commit 5b4d9bb

File tree

5 files changed

+115
-1
lines changed

5 files changed

+115
-1
lines changed

kernels/optimized/cpu/op_sigmoid.cpp

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <cmath>
10+
11+
#include <executorch/kernels/optimized/vec/functional.h>
12+
#include <executorch/kernels/optimized/vec/vec.h>
13+
#include <executorch/runtime/kernel/kernel_includes.h>
14+
15+
namespace torch {
16+
namespace executor {
17+
namespace native {
18+
19+
namespace {
20+
21+
template <typename T>
22+
constexpr bool is_half_or_bf16_v = std::is_same_v<T, exec_aten::Half> ||
23+
std::is_same_v<T, exec_aten::BFloat16>;
24+
25+
template <
26+
typename CTYPE_IN,
27+
typename CTYPE_OUT,
28+
typename std::enable_if<
29+
std::is_same_v<CTYPE_IN, CTYPE_OUT> && !is_half_or_bf16_v<CTYPE_IN> &&
30+
!is_half_or_bf16_v<CTYPE_OUT>,
31+
int>::type = 0>
32+
void sigmoid_data(
33+
const CTYPE_IN* in_data,
34+
const size_t numel,
35+
CTYPE_OUT* out_data) {
36+
using Vec = executorch::vec::Vectorized<CTYPE_IN>;
37+
executorch::vec::map<CTYPE_IN>(
38+
[](Vec x) {
39+
auto one_plus_exp = x.neg().exp() + Vec(static_cast<CTYPE_IN>(1.0));
40+
return one_plus_exp.reciprocal();
41+
},
42+
out_data,
43+
in_data,
44+
numel);
45+
}
46+
47+
template <
48+
typename CTYPE_IN,
49+
typename CTYPE_OUT,
50+
typename std::enable_if<
51+
!std::is_same_v<CTYPE_IN, CTYPE_OUT> || is_half_or_bf16_v<CTYPE_IN> ||
52+
is_half_or_bf16_v<CTYPE_OUT>,
53+
int>::type = 0>
54+
void sigmoid_data(
55+
const CTYPE_IN* in_data,
56+
const size_t numel,
57+
CTYPE_OUT* out_data) {
58+
for (size_t i = 0; i < numel; i++) {
59+
CTYPE_OUT xi = static_cast<CTYPE_OUT>(in_data[i]);
60+
out_data[i] = (1.0f / (1.0f + std::exp(-xi)));
61+
}
62+
}
63+
64+
} // namespace
65+
66+
using Tensor = exec_aten::Tensor;
67+
68+
// sigmoid.out: writes sigmoid(in) = 1 / (1 + exp(-in)) elementwise into
// `out` and returns `out`.
//
// Checks performed (each short-circuits with InvalidArgument on failure):
//  - `in` may be any REALHB dtype except Bool;
//  - `out` must be a floating-point dtype;
//  - `in` and `out` must share the same dim order;
//  - `out` is resized to `in`'s shape (dynamic-shape support).
// Dispatch then selects the vectorized sigmoid_data overload when the
// input/output types permit, otherwise the scalar fallback.
Tensor&
opt_sigmoid_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
  // Note: `ctx` is used by every ET_KERNEL_CHECK below, so no (void)ctx
  // suppression is needed.
  ET_KERNEL_CHECK(
      ctx, in.scalar_type() != ScalarType::Bool, InvalidArgument, out);
  ET_KERNEL_CHECK(ctx, tensor_is_floating_type(out), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  // Resize for dynamic shape.
  ET_KERNEL_CHECK_MSG(
      ctx,
      resize_tensor(out, in.sizes()) == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor.");

  ScalarType in_type = in.scalar_type();
  ScalarType out_type = out.scalar_type();
  // Bool inputs are rejected above, so the Bool branch of the REALHB switch
  // is compiled but never taken at runtime.
  // NOTE(review): ET_SWITCH_FLOATH_TYPES covers Float/Double/Half only, so a
  // BFloat16 `out` would fail dispatch here — confirm that is intended.
  ET_SWITCH_REALHB_TYPES(in_type, ctx, "sigmoid.out", CTYPE_IN, [&]() {
    ET_SWITCH_FLOATH_TYPES(out_type, ctx, "sigmoid.out", CTYPE_OUT, [&]() {
      sigmoid_data<CTYPE_IN, CTYPE_OUT>(
          in.const_data_ptr<CTYPE_IN>(),
          in.numel(),
          out.mutable_data_ptr<CTYPE_OUT>());
    });
  });

  return out;
}
100+
101+
} // namespace native
102+
} // namespace executor
103+
} // namespace torch

kernels/optimized/cpu/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ _OPTIMIZED_ATEN_OPS = (
2525
],
2626
),
2727
op_target(name = "op_exp"),
28+
op_target(name = "op_sigmoid"),
2829
op_target(
2930
name = "op_gelu",
3031
deps = select({

kernels/optimized/optimized-oss.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@
3535
- arg_meta: null
3636
kernel_name: torch::executor::opt_exp_out
3737

38+
- op: sigmoid.out
39+
kernels:
40+
- arg_meta: null
41+
kernel_name: torch::executor::opt_sigmoid_out
42+
3843
- op: le.Scalar_out
3944
kernels:
4045
- arg_meta: null

kernels/optimized/optimized.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@
3737
- arg_meta: null
3838
kernel_name: torch::executor::opt_exp_out
3939

40+
- op: sigmoid.out
41+
kernels:
42+
- arg_meta: null
43+
kernel_name: torch::executor::opt_sigmoid_out
44+
4045
- op: gelu.out
4146
kernels:
4247
- arg_meta: null

kernels/test/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def define_common_targets():
297297
_common_op_test("op_scatter_add_test", ["aten", "portable"])
298298
_common_op_test("op_select_scatter_test", ["aten", "portable"])
299299
_common_op_test("op_select_copy_test", ["aten", "portable"])
300-
_common_op_test("op_sigmoid_test", ["aten", "portable"])
300+
_common_op_test("op_sigmoid_test", ["aten", "portable", "optimized"])
301301
_common_op_test("op_sign_test", ["aten", "portable"])
302302
_common_op_test("op_sin_test", ["aten", "portable"])
303303
_common_op_test("op_sinh_test", ["aten", "portable"])

0 commit comments

Comments
 (0)