@@ -14,6 +14,8 @@
 #include <cmath>
 #include <type_traits>
 
+#include <ATen/cpu/vec/functional.h>
+#include <ATen/cpu/vec/vec.h>
 #include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
@@ -66,30 +68,30 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
       }
       // calculate sum and exponential in softmax dim
       OUT_T temp_sum = 0;
-#ifndef __aarch64__
-      for (auto d = 0; d < dim_size; ++d) {
-        output_data[d * dim_stride] =
-            std::exp(input_data[d * dim_stride] - max_input);
-        temp_sum += output_data[d * dim_stride];
-      }
-#else
+      using VecOut = at::vec::Vectorized<OUT_T>;
+      using VecIn = at::vec::Vectorized<IN_T>;
       auto d = 0;
-      for (; d + 4 < dim_size; d += 4) {
+      static_assert(sizeof(IN_T) == sizeof(OUT_T));
+      static_assert(
+          std::is_same_v<OUT_T, float>,
+          "Below loop actually only supports float.");
+      const VecIn max_input_vec(max_input);
+      for (; d + VecOut::size() < dim_size; d += VecOut::size()) {
         auto index = d * dim_stride;
-        float32x4_t in =
-            vld1q_f32(static_cast<const float*>(&input_data[index]));
-        float32x4_t out_ =
-            Sleef_expf4_u10(vsubq_f32(in, vmovq_n_f32(max_input)));
-        vst1q_f32(static_cast<float*>(&output_data[index]), out_);
+        auto in = VecIn::loadu(&input_data[index]);
+        auto out_ = (in - max_input_vec).exp();
+        out_.store(&output_data[index]);
+#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE)
         temp_sum += vaddvq_f32(out_);
+#else
+        temp_sum += at::vec::vec_reduce_all<float>(std::plus<VecOut>(), out_);
+#endif
       }
-
       for (; d < dim_size; ++d) {
         output_data[d * dim_stride] =
            std::exp(input_data[d * dim_stride] - max_input);
        temp_sum += output_data[d * dim_stride];
      }
-#endif // __aarch64__
 
       temp_sum = std::log(temp_sum);
 
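For reference, here is a minimal, self-contained sketch of the at::vec::Vectorized pattern this diff switches to: load a lane-width of inputs with loadu, subtract the broadcast max, exponentiate, store, and horizontally reduce the partial sums, with a scalar tail loop for the leftovers. This is not the ExecuTorch kernel itself; sum_of_exps and its parameters are illustrative names, and it assumes the ATen vec headers are on the include path.

#include <cmath>
#include <cstdint>
#include <functional>

#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>

// Computes output[i] = exp(input[i] - max_input) over a contiguous float
// buffer and returns the sum of the outputs.
float sum_of_exps(
    const float* input,
    float* output,
    int64_t n,
    float max_input) {
  using Vec = at::vec::Vectorized<float>;
  const Vec max_vec(max_input); // broadcast max_input into every lane
  float sum = 0;
  int64_t d = 0;
  // Main loop: process Vec::size() contiguous floats per iteration.
  for (; d + Vec::size() <= n; d += Vec::size()) {
    Vec out = (Vec::loadu(input + d) - max_vec).exp();
    out.store(output + d);
    // Horizontal add of all lanes into a scalar partial sum.
    sum += at::vec::vec_reduce_all<float>(std::plus<Vec>(), out);
  }
  // Scalar tail for the remaining n % Vec::size() elements.
  for (; d < n; ++d) {
    output[d] = std::exp(input[d] - max_input);
    sum += output[d];
  }
  return sum;
}

One design choice visible in the diff: on plain aarch64 builds (NEON, no SVE), where Vectorized<float> wraps a float32x4_t, it keeps the vaddvq_f32 intrinsic for the horizontal add, and only falls back to at::vec::vec_reduce_all on other targets.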