
Commit cc3974f

Authored by pytorchbot and Github Executorch
Use at::Vectorized in optimized log_softmax
Pull Request resolved: #8382

This should allow us to enable this op in OSS, because Vectorized handles any Sleef issues for us as needed. (I considered going straight to sharing the PyTorch core implementation, but we need parallel_for enabled for that, and this improvement is easy enough to make.)

Differential Revision: [D69473208](https://our.internmc.facebook.com/intern/diff/D69473208/)
ghstack-source-id: 267044107
Co-authored-by: Github Executorch <[email protected]>
1 parent 8d1480b commit cc3974f
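For readers unfamiliar with the at::vec API, here is a minimal standalone sketch of the pattern the new kernel code (in the diff below) follows, assuming contiguous float data and a non-empty buffer; exp_shifted_and_sum is an illustrative helper invented for this sketch, not part of the commit. It uses only the calls that appear in the diff: Vectorized<float>::loadu, operator-, exp, store, size, and at::vec::vec_reduce_all.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <functional>

#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>

// Illustrative helper: writes out[i] = exp(in[i] - max(in)) for i in [0, n)
// and returns the sum of the outputs. The main loop is vectorized with
// at::vec::Vectorized<float>; the remainder is handled by scalar code.
float exp_shifted_and_sum(const float* in, float* out, int64_t n) {
  using Vec = at::vec::Vectorized<float>;
  const float max_input = *std::max_element(in, in + n);  // assumes n > 0
  const Vec max_vec(max_input);
  float sum = 0;
  int64_t i = 0;
  for (; i + Vec::size() <= n; i += Vec::size()) {
    auto v = Vec::loadu(in + i);   // unaligned vector load
    auto e = (v - max_vec).exp();  // Vectorized::exp dispatches to Sleef/libm/etc. per build
    e.store(out + i);
    sum += at::vec::vec_reduce_all<float>(std::plus<Vec>(), e);  // horizontal add
  }
  for (; i < n; ++i) {  // scalar tail
    out[i] = std::exp(in[i] - max_input);
    sum += out[i];
  }
  return sum;
}

Because the same source now expresses both the scalar and SIMD paths, the per-architecture Sleef dependency in targets.bzl can be replaced by the ATen headers, which is what allows the op to be enabled in OSS.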

File tree: 2 files changed, +21 −24 lines
kernels/optimized/cpu/op_log_softmax.cpp
kernels/optimized/cpu/targets.bzl

kernels/optimized/cpu/op_log_softmax.cpp

Lines changed: 17 additions & 15 deletions
@@ -14,6 +14,8 @@
 #include <cmath>
 #include <type_traits>
 
+#include <ATen/cpu/vec/functional.h>
+#include <ATen/cpu/vec/vec.h>
 #include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
@@ -66,30 +68,30 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
     }
     // calculate sum and exponential in softmax dim
     OUT_T temp_sum = 0;
-#ifndef __aarch64__
-    for (auto d = 0; d < dim_size; ++d) {
-      output_data[d * dim_stride] =
-          std::exp(input_data[d * dim_stride] - max_input);
-      temp_sum += output_data[d * dim_stride];
-    }
-#else
+    using VecOut = at::vec::Vectorized<OUT_T>;
+    using VecIn = at::vec::Vectorized<IN_T>;
     auto d = 0;
-    for (; d + 4 < dim_size; d += 4) {
+    static_assert(sizeof(IN_T) == sizeof(OUT_T));
+    static_assert(
+        std::is_same_v<OUT_T, float>,
+        "Below loop actually only supports float.");
+    const VecIn max_input_vec(max_input);
+    for (; d + VecOut::size() < dim_size; d += VecOut::size()) {
       auto index = d * dim_stride;
-      float32x4_t in =
-          vld1q_f32(static_cast<const float*>(&input_data[index]));
-      float32x4_t out_ =
-          Sleef_expf4_u10(vsubq_f32(in, vmovq_n_f32(max_input)));
-      vst1q_f32(static_cast<float*>(&output_data[index]), out_);
+      auto in = VecIn::loadu(&input_data[index]);
+      auto out_ = (in - max_input_vec).exp();
+      out_.store(&output_data[index]);
+#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE)
       temp_sum += vaddvq_f32(out_);
+#else
+      temp_sum += at::vec::vec_reduce_all<float>(std::plus<VecOut>(), out_);
+#endif
     }
-
     for (; d < dim_size; ++d) {
       output_data[d * dim_stride] =
           std::exp(input_data[d * dim_stride] - max_input);
       temp_sum += output_data[d * dim_stride];
     }
-#endif // __aarch64__
 
     temp_sum = std::log(temp_sum);
 
kernels/optimized/cpu/targets.bzl

Lines changed: 4 additions & 9 deletions
@@ -57,15 +57,10 @@ _OPTIMIZED_ATEN_OPS = (
     ),
     op_target(
         name = "op_log_softmax",
-        deps = select({
-            "DEFAULT": [
-                "//executorch/kernels/portable/cpu/util:activation_ops_util",
-            ],
-            "ovr_config//cpu:arm64": [
-                "//executorch/kernels/portable/cpu/util:activation_ops_util",
-                "fbsource//third-party/sleef:sleef_arm",
-            ],
-        }),
+        deps = [
+            "//executorch/kernels/portable/cpu/util:activation_ops_util",
+            "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
+        ],
     ),
     op_target(
         name = "op_mm",
