Commit b612e5b

fix numerics for internal test on "Remove ExecuTorch copy of Vectorized"

All uses are outside ExecuTorch core, so we can just use ATen Vectorized.

Differential Revision: [D66396016](https://our.internmc.facebook.com/intern/diff/D66396016/)

[ghstack-poisoned]
2 parents: 6133476 + 6c9d27e
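
The numerics fix comes down to floating-point summation order: float addition is not associative, so a horizontal SIMD reduction such as vec_reduce_all (which, per the note in the diff, uses vaddvq_f32 internally) can yield a different sum than sequential scalar accumulation. A minimal standalone sketch of the effect, not part of the commit:

// Sketch (illustrative, not from the diff): reduction order changes
// float results because float addition is not associative.
#include <cstdio>

int main() {
  float a[4] = {1e8f, 1.0f, -1e8f, 1.0f};
  // Sequential order, as in the commit's scalar lane loop:
  // 1e8f + 1.0f rounds back to 1e8f (ulp at 1e8 is 8).
  float seq = ((a[0] + a[1]) + a[2]) + a[3];
  // Pairwise order, as a horizontal SIMD reduction might use:
  float pair = (a[0] + a[2]) + (a[1] + a[3]);
  std::printf("sequential=%g pairwise=%g\n", seq, pair);  // prints 1 vs 2
  return 0;
}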

1 file changed: extension/llm/custom_ops/op_sdpa_impl.h (25 additions, 9 deletions)

@@ -9,6 +9,7 @@
 #pragma once
 
 #include <ATen/cpu/vec/vec.h>
+#include <ATen/cpu/vec/vec_n.h>
 #include <executorch/kernels/optimized/blas/CPUBlas.h>
 #include <executorch/kernels/optimized/vec/functional.h>
 #include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
@@ -362,22 +363,37 @@ using Tensor = ::executorch::aten::Tensor;
 template <typename T1, typename T2>
 inline void
 _exp_reduce_sum_fusion_kernel(T1* a, const int& size, T2* out, T1& val) {
-  auto vec_size = vec::Vectorized<T1>::size();
-  auto vec_max = vec::Vectorized<T1>(val);
+  // NOTE: we observed numerics issues with this function when
+  // deleting the old executorch::vec and replacing with at::vec
+  // here. The major known difference is that executorch::vec was 256
+  // bits wide vs 128 bits for at::vec (and the hardware). Preserving
+  // this function's execution width at 256 bits and avoiding
+  // vec_reduce_all below removed the issues.
+  constexpr auto vec_size = vec::Vectorized<T1>::size() * 2;
+  auto vec_max = vec::VectorizedN<T1, 2>(val);
   T1 tmp_sum = 0;
-  auto vec_tmp_sum = vec::Vectorized<T1>(tmp_sum);
+  auto vec_tmp_sum = vec::VectorizedN<T1, 2>(tmp_sum);
   for (int i = 0; i < vec_size * (size / vec_size); i += vec_size) {
-    auto tmp0 = vec::Vectorized<T1>::loadu(a + i);
+    auto tmp0 = vec::VectorizedN<T1, 2>::loadu(a + i);
     auto tmp1 = tmp0 - vec_max;
     // Replace with exp_u20 later
     // auto tmp2 = tmp1.exp_u20();
     auto tmp2 = tmp1.exp();
-    vec_tmp_sum += tmp2;
-    _store(out + i, tmp2);
+    vec_tmp_sum = vec_tmp_sum + tmp2;
+    tmp2.store(out + i);
   }
-  tmp_sum = vec::vec_reduce_all<T1>(
-      [](vec::Vectorized<T1>& x, vec::Vectorized<T1>& y) { return x + y; },
-      vec_tmp_sum);
+
+  __at_align__ T1 vec_tmp_sum_array[vec_size];
+  vec_tmp_sum.store(vec_tmp_sum_array);
+  for (const auto i : c10::irange(vec_size)) {
+    tmp_sum += vec_tmp_sum_array[i];
+  }
+  // See NOTE above; we should replace the scalar reduction above with
+  // this reduction (which uses vaddvq_f32 internally), but it changes
+  // numerics.
+  // tmp_sum = vec::vec_reduce_all<T1>(
+  //     [](vec::Vectorized<T1>& x, vec::Vectorized<T1>& y) { return x + y; },
+  //     vec_tmp_sum);
   for (int i = vec_size * (size / vec_size); i < size; i++) {
    auto tmp0 = a[i];
    auto tmp1 = tmp0 - val;
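
For reference outside the diff context, the committed accumulation pattern is roughly the standalone sketch below, assuming the ATen headers added above; the helper name sum_256 and the float specialization are hypothetical. It accumulates across two 128-bit lanes with VectorizedN<float, 2>, then reduces the lanes with a plain scalar loop rather than vec_reduce_all:

// Standalone sketch of the committed reduction pattern; sum_256 is a
// hypothetical name, not part of the commit.
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/vec_n.h>
#include <c10/util/irange.h>

namespace vec = at::vec;

float sum_256(const float* a, int size) {
  // Two 128-bit Vectorized lanes give a 256-bit effective width,
  // matching the old executorch::vec.
  constexpr auto vec_size = vec::Vectorized<float>::size() * 2;
  auto vec_tmp_sum = vec::VectorizedN<float, 2>(0.0f);
  int i = 0;
  for (; i < vec_size * (size / vec_size); i += vec_size) {
    vec_tmp_sum = vec_tmp_sum + vec::VectorizedN<float, 2>::loadu(a + i);
  }
  // Scalar lane reduction instead of vec_reduce_all; this preserves a
  // sequential summation order, which is what the test numerics
  // turned out to depend on.
  __at_align__ float lanes[vec_size];
  vec_tmp_sum.store(lanes);
  float tmp_sum = 0;
  for (const auto j : c10::irange(vec_size)) {
    tmp_sum += lanes[j];
  }
  // Scalar tail for the remainder, as in the kernel.
  for (; i < size; i++) {
    tmp_sum += a[i];
  }
  return tmp_sum;
}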
