 #pragma once
 
 #include <ATen/cpu/vec/vec.h>
+#include <ATen/cpu/vec/vec_n.h>
 #include <executorch/kernels/optimized/blas/CPUBlas.h>
 #include <executorch/kernels/optimized/vec/functional.h>
 #include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
@@ -362,22 +363,37 @@ using Tensor = ::executorch::aten::Tensor;
 template <typename T1, typename T2>
 inline void
 _exp_reduce_sum_fusion_kernel(T1* a, const int& size, T2* out, T1& val) {
-  auto vec_size = vec::Vectorized<T1>::size();
-  auto vec_max = vec::Vectorized<T1>(val);
+  // NOTE: we observed numerics issues with this function when
+  // deleting the old executorch::vec and replacing with at::vec
+  // here. The major known difference is that executorch::vec was 256
+  // bits wide vs 128 bits for at::vec (and the hardware). Preserving
+  // this function's execution width at 256 bits and avoiding
+  // vec_reduce_all below removed the issues.
+  constexpr auto vec_size = vec::Vectorized<T1>::size() * 2;
+  auto vec_max = vec::VectorizedN<T1, 2>(val);
   T1 tmp_sum = 0;
-  auto vec_tmp_sum = vec::Vectorized<T1>(tmp_sum);
+  auto vec_tmp_sum = vec::VectorizedN<T1, 2>(tmp_sum);
   for (int i = 0; i < vec_size * (size / vec_size); i += vec_size) {
-    auto tmp0 = vec::Vectorized<T1>::loadu(a + i);
+    auto tmp0 = vec::VectorizedN<T1, 2>::loadu(a + i);
     auto tmp1 = tmp0 - vec_max;
     // Replace with exp_u20 later
     // auto tmp2 = tmp1.exp_u20();
     auto tmp2 = tmp1.exp();
-    vec_tmp_sum += tmp2;
-    _store(out + i, tmp2);
+    vec_tmp_sum = vec_tmp_sum + tmp2;
+    tmp2.store(out + i);
   }
-  tmp_sum = vec::vec_reduce_all<T1>(
-      [](vec::Vectorized<T1>& x, vec::Vectorized<T1>& y) { return x + y; },
-      vec_tmp_sum);
+
+  __at_align__ T1 vec_tmp_sum_array[vec_size];
+  vec_tmp_sum.store(vec_tmp_sum_array);
+  for (const auto i : c10::irange(vec_size)) {
+    tmp_sum += vec_tmp_sum_array[i];
+  }
+  // See NOTE above; we should replace the scalar reduction above with
+  // this reduction (which uses vaddvq_f32 internally), but it changes
+  // numerics.
+  // tmp_sum = vec::vec_reduce_all<T1>(
+  //     [](vec::Vectorized<T1>& x, vec::Vectorized<T1>& y) { return x + y; },
+  //     vec_tmp_sum);
   for (int i = vec_size * (size / vec_size); i < size; i++) {
     auto tmp0 = a[i];
     auto tmp1 = tmp0 - val;