@@ -17,13 +17,17 @@ namespace executor {
17
17
namespace native {
18
18
19
19
namespace {
20
+
21
+ template <typename T>
22
+ constexpr bool is_half_or_bf16_v = std::is_same_v<T, exec_aten::Half> ||
23
+ std::is_same_v<T, exec_aten::BFloat16>;
24
+
20
25
template <
21
26
typename CTYPE_IN,
22
27
typename CTYPE_OUT,
23
28
typename std::enable_if<
24
- std::is_same_v<CTYPE_IN, CTYPE_OUT> &&
25
- !std::is_same_v<CTYPE_IN, exec_aten::Half> &&
26
- !std::is_same_v<CTYPE_OUT, exec_aten::BFloat16>,
29
+ std::is_same_v<CTYPE_IN, CTYPE_OUT> && !is_half_or_bf16_v<CTYPE_IN> &&
30
+ !is_half_or_bf16_v<CTYPE_OUT>,
27
31
int >::type = 0 >
28
32
void sigmoid_data (
29
33
const CTYPE_IN* in_data,
@@ -32,7 +36,7 @@ void sigmoid_data(
32
36
using Vec = executorch::vec::Vectorized<CTYPE_IN>;
33
37
executorch::vec::map<CTYPE_IN>(
34
38
[](Vec x) {
35
- auto one_plus_exp = x.neg ().exp () + Vec (1.0 );
39
+ auto one_plus_exp = x.neg ().exp () + Vec (static_cast <CTYPE_IN>( 1.0 ) );
36
40
return one_plus_exp.reciprocal ();
37
41
},
38
42
out_data,
@@ -44,19 +48,16 @@ template <
44
48
typename CTYPE_IN,
45
49
typename CTYPE_OUT,
46
50
typename std::enable_if<
47
- !std::is_same_v<CTYPE_IN, CTYPE_OUT> ||
48
- std::is_same_v<CTYPE_IN, exec_aten::Half> ||
49
- std::is_same_v<CTYPE_IN, exec_aten::BFloat16> ||
50
- std::is_same_v<CTYPE_OUT, exec_aten::Half> ||
51
- std::is_same_v<CTYPE_OUT, exec_aten::BFloat16>,
51
+ !std::is_same_v<CTYPE_IN, CTYPE_OUT> || is_half_or_bf16_v<CTYPE_IN> ||
52
+ is_half_or_bf16_v<CTYPE_OUT>,
52
53
int >::type = 0 >
53
54
void sigmoid_data (
54
55
const CTYPE_IN* in_data,
55
56
const size_t numel,
56
57
CTYPE_OUT* out_data) {
57
58
for (size_t i = 0 ; i < numel; i++) {
58
59
CTYPE_OUT xi = static_cast <CTYPE_OUT>(in_data[i]);
59
- out_data[i] = (1.0 / (1.0 + std::exp (-xi)));
60
+ out_data[i] = (1 .0f / (1 .0f + std::exp (-xi)));
60
61
}
61
62
}
62
63
0 commit comments