Commit ccd2fe8

kimishpatel authored and facebook-github-bot committed
Make sdpa_with_kv_cache thread parallel (#2501)
Summary: Pull Request resolved: #2501
ghstack-source-id: 219444232
Reviewed By: digantdesai
Differential Revision: D55047311
fbshipit-source-id: 736eafc944ca575497f67b5af793463dcb573cdd
1 parent 08733f0 commit ccd2fe8
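In short: the custom SDPA kernel previously walked every (batch, head, q-block) slice on a single thread through a local util::parallel_for helper. This commit deletes that helper and dispatches the same per-slice work through torch::executor::parallel_for, with the thread count taken from the XNNPACK threadpool when ET_USE_THREADPOOL is defined. A minimal sketch of the pattern adopted below (names are taken from the diff; the loop body is elided):

    // Each work item is one q-block of one head of one batch element.
    // parallel_for and get_thread_num come from
    // executorch/extension/parallel/thread_parallel.h.
    auto compute_lambda = [&](int64_t begin, int64_t end) {
      int ompIdx = torch::executor::get_thread_num(); // index of this worker
      accum_t* buf_ptr = buf_data + ompIdx * size_per_thread; // private scratch
      for (int64_t z = begin; z < end; z++) {
        // ... flash-attention math for one (batch, head, q-block) slice ...
      }
    };
    torch::executor::parallel_for(
        0, batchSize * num_head * qSlice, /*grain_size=*/1, compute_lambda);

Because each worker indexes its own slice of buf_data by get_thread_num(), no two threads share scratch buffers and no extra synchronization is needed inside the lambda.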

File tree

5 files changed: 169 additions & 154 deletions

examples/models/llama2/custom_ops/op_sdpa.cpp

Lines changed: 146 additions & 152 deletions
@@ -17,16 +17,22 @@
 
 #include <array>
 
+#ifdef ET_USE_THREADPOOL
+#include <executorch/backends/xnnpack/threadpool/threadpool.h>
+#include <executorch/extension/parallel/thread_parallel.h>
+#endif
+
 namespace torch {
 namespace executor {
+
 namespace native {
 
 namespace util {
 
 constexpr size_t kKVDim = 4;
 
 template <typename T>
-inline void _store(T* dst, executorch::vec::Vectorized<T> src) {
+inline void _store(T* dst, ::executorch::vec::Vectorized<T> src) {
   src.store(dst);
 }
 
@@ -38,19 +44,6 @@ inline void _store(::Half* dst, at::vec::Vectorized<float> src) {
 }
 */
 
-template <class F>
-inline void parallel_for(
-    const int64_t begin,
-    const int64_t end,
-    const int64_t grain_size,
-    const F& f) {
-  for (int64_t i = begin; i < end; i += grain_size) {
-    int64_t task_begin = i;
-    int64_t task_end = std::min(task_begin + grain_size, end);
-    f(task_begin, task_end);
-  }
-}
-
 template <typename T>
 inline T data_index_init(T offset) {
   return offset;
@@ -83,7 +76,7 @@ inline double calculate_scale(const Tensor& query, optional<double> scale) {
 }
 
 } // namespace util
-namespace vec = executorch::vec;
+namespace vec = ::executorch::vec;
 using Tensor = exec_aten::Tensor;
 
 namespace {
@@ -310,8 +303,12 @@ void cpu_flash_attention(
   int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
   int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
   int64_t qSlice = (qSize - 1) / qSplitSize + 1;
-  // int64_t num_thread = at::get_num_threads();
-  int64_t num_thread = 1; // at::get_num_threads();
+#ifdef ET_USE_THREADPOOL
+  int64_t num_thread =
+      torch::executorch::threadpool::get_threadpool()->get_thread_count();
+#else
+  int64_t num_thread = 1;
+#endif
 
   // const auto dtype = query.scalar_type();
   // Following will be revisited in the future
@@ -346,149 +343,146 @@ void cpu_flash_attention(
   scalar_t* buf_reduced_data =
       is_reduced_type ? reinterpret_cast<scalar_t*>(buf_reduced) : nullptr;
 
-  util::parallel_for(
-      0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) {
-        int64_t i = 0, j = 0, k = 0;
-        util::data_index_init(begin, i, batchSize, j, num_head, k, qSlice);
-        int ompIdx = 0; // at::get_thread_num();
-        accum_t* buf_ptr = buf_data + ompIdx * size_per_thread;
-        accum_t* qk_data = buf_ptr;
-        accum_t* qk_max_data = qk_data + qSplitSize * kvSplitSize;
-        accum_t* qk_sum_data = qk_max_data + qSplitSize;
-        accum_t* dst_data = qk_sum_data + qSplitSize;
-        scalar_t* qk_reduced_data = is_reduced_type
-            ? buf_reduced_data + ompIdx * qSplitSize * kvSplitSize
-            : nullptr;
-
-        for (int64_t z = begin; z < end; z++) {
-          int64_t m = k * qSplitSize;
-          int64_t qBlockSize = std::min(qSplitSize, qSize - m);
-          // Initialize max and sum
-          fill_stub(
-              qk_max_data,
-              -std::numeric_limits<accum_t>::infinity(),
-              qBlockSize);
-          int64_t num_keys =
-              is_causal ? std::min(m + qBlockSize, kvSize) : kvSize;
-          for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
-            int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
-            // Calculate scale * q @ k.T
-            fill_stub(
-                qk_data, static_cast<accum_t>(0), qSplitSize * kvSplitSize);
-            executorch::cpublas::gemm(
-                executorch::cpublas::TransposeType::Transpose,
-                executorch::cpublas::TransposeType::NoTranspose,
-                kvBlockSize,
-                qBlockSize,
-                headSize,
-                static_cast<accum_t>(1),
-                k_data + i * kStrideB + j * kStrideH + n * kStrideN,
-                kStrideN,
-                q_data + i * qStrideB + j * qStrideH + m * qStrideM,
-                qStrideM,
-                static_cast<accum_t>(0),
-                qk_data,
-                kvBlockSize);
-            // Apply causal mask, fill unused with -inf
-            if (is_causal && num_keys - n <= kvSplitSize) {
-              for (int32_t row = 0; row < qBlockSize; ++row) {
-                int64_t last_col = m + row - n;
-                accum_t* row_ptr = qk_data + row * kvBlockSize;
-                fill_stub(
-                    row_ptr + last_col + 1,
-                    -std::numeric_limits<accum_t>::infinity(),
-                    kvBlockSize - last_col - 1);
-              }
-            }
-            // Update attention weights with attention mask
-            // And apply scaling factor
-            // qk <- qk * scaling + attn_mask
-            if (has_attn_mask) {
-              for (int64_t row = 0; row < qBlockSize; ++row) {
-                vec::map2<accum_t>(
-                    [scaling_factor](Vec x, Vec y) {
-                      return x * Vec(scaling_factor) + y;
-                    },
-                    qk_data + row * kvBlockSize,
-                    qk_data + row * kvBlockSize,
-                    mask_data + i * mStrideB + j * mStrideH +
-                        (m + row) * mStrideM + n,
-                    kvBlockSize);
-              }
-            }
-            // Update coefficients with Softmax
-            accum_t tmp_max = 0, tmp_sum = 0, exp_tmp = 0;
-            for (int64_t row = 0; row < qBlockSize; ++row) {
-              if (has_attn_mask) {
-                // max per row
-                tmp_max = vec::reduce_all<accum_t>(
-                    [](Vec& x, Vec& y) { return vec::maximum(x, y); },
-                    qk_data + row * kvBlockSize,
-                    kvBlockSize);
-              } else {
-                // apply scaling factor and max per row in fusion
-                _mul_reduce_max_fusion_kernel(
-                    qk_data + row * kvBlockSize,
-                    scaling_factor,
-                    kvBlockSize,
-                    qk_data + row * kvBlockSize,
-                    tmp_max);
-              }
-              tmp_max = qk_max_data[row] > tmp_max ? qk_max_data[row] : tmp_max;
-              // qk <- exp(qk - max) and sum per row
-              tmp_sum = tmp_max;
-              _exp_reduce_sum_fusion_kernel(
-                  qk_data + row * kvBlockSize,
-                  kvBlockSize,
-                  conditional_data_ptr(qk_data, qk_reduced_data) +
-                      row * kvBlockSize,
-                  tmp_sum);
-              // exp_tmp <- exp(max[row] - max)
-              exp_tmp = std::exp(qk_max_data[row] - tmp_max);
-              // sum[row] <- sum + exp_tmp * sum[row]
-              qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row];
-              // max[row] <- max
-              qk_max_data[row] = tmp_max;
-              // dst <- dst * exp_tmp
-              if (n > 0) {
-                vec::map<accum_t>(
-                    [exp_tmp](Vec x) { return x * Vec(exp_tmp); },
-                    dst_data + row * headSize,
-                    dst_data + row * headSize,
-                    headSize);
-              }
-            }
-            // Calculate Softmax(q @ k.T) @ v
-            executorch::cpublas::gemm(
-                executorch::cpublas::TransposeType::NoTranspose,
-                executorch::cpublas::TransposeType::NoTranspose,
-                headSize,
-                qBlockSize,
-                kvBlockSize,
-                static_cast<accum_t>(1),
-                v_data + i * vStrideB + j * vStrideH + n * vStrideN,
-                vStrideN,
-                conditional_data_ptr(qk_data, qk_reduced_data),
-                kvBlockSize,
-                n == 0 ? static_cast<accum_t>(0) : static_cast<accum_t>(1),
-                dst_data,
-                headSize);
-          }
-          // dst <- dst / sum[row]
-          // reorder MHA output with strides
-          for (int64_t row = 0; row < qBlockSize; ++row) {
-            accum_t sum_reciprocal = 1 / qk_sum_data[row];
-            vec::map<scalar_t>(
-                [sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); },
-                out_data + i * oStrideB + j * oStrideH + m * oStrideM +
-                    row * oStrideM,
-                dst_data + row * headSize,
-                headSize);
-          }
-          // Move to the next query
-          util::data_index_step(i, batchSize, j, num_head, k, qSlice);
-        }
-      });
+  auto compute_lambda = [&](int64_t begin, int64_t end) {
+    int64_t i = 0, j = 0, k = 0;
+    util::data_index_init(begin, i, batchSize, j, num_head, k, qSlice);
+    int ompIdx = torch::executor::get_thread_num();
+    accum_t* buf_ptr = buf_data + ompIdx * size_per_thread;
+    accum_t* qk_data = buf_ptr;
+    accum_t* qk_max_data = qk_data + qSplitSize * kvSplitSize;
+    accum_t* qk_sum_data = qk_max_data + qSplitSize;
+    accum_t* dst_data = qk_sum_data + qSplitSize;
+    scalar_t* qk_reduced_data = is_reduced_type
+        ? buf_reduced_data + ompIdx * qSplitSize * kvSplitSize
+        : nullptr;
+
+    for (int64_t z = begin; z < end; z++) {
+      int64_t m = k * qSplitSize;
+      int64_t qBlockSize = std::min(qSplitSize, qSize - m);
+      // Initialize max and sum
+      fill_stub(
+          qk_max_data, -std::numeric_limits<accum_t>::infinity(), qBlockSize);
+      int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize;
+      for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
+        int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
+        // Calculate scale * q @ k.T
+        fill_stub(qk_data, static_cast<accum_t>(0), qSplitSize * kvSplitSize);
+        ::executorch::cpublas::gemm(
+            ::executorch::cpublas::TransposeType::Transpose,
+            ::executorch::cpublas::TransposeType::NoTranspose,
+            kvBlockSize,
+            qBlockSize,
+            headSize,
+            static_cast<accum_t>(1),
+            k_data + i * kStrideB + j * kStrideH + n * kStrideN,
+            kStrideN,
+            q_data + i * qStrideB + j * qStrideH + m * qStrideM,
+            qStrideM,
+            static_cast<accum_t>(0),
+            qk_data,
+            kvBlockSize);
+        // Apply causal mask, fill unused with -inf
+        if (is_causal && num_keys - n <= kvSplitSize) {
+          for (int32_t row = 0; row < qBlockSize; ++row) {
+            int64_t last_col = m + row - n;
+            accum_t* row_ptr = qk_data + row * kvBlockSize;
+            fill_stub(
+                row_ptr + last_col + 1,
+                -std::numeric_limits<accum_t>::infinity(),
+                kvBlockSize - last_col - 1);
+          }
+        }
+        // Update attention weights with attention mask
+        // And apply scaling factor
+        // qk <- qk * scaling + attn_mask
+        if (has_attn_mask) {
+          for (int64_t row = 0; row < qBlockSize; ++row) {
+            vec::map2<accum_t>(
+                [scaling_factor](Vec x, Vec y) {
+                  return x * Vec(scaling_factor) + y;
+                },
+                qk_data + row * kvBlockSize,
+                qk_data + row * kvBlockSize,
+                mask_data + i * mStrideB + j * mStrideH + (m + row) * mStrideM +
+                    n,
+                kvBlockSize);
+          }
+        }
+        // Update coefficients with Softmax
+        accum_t tmp_max = 0, tmp_sum = 0, exp_tmp = 0;
+        for (int64_t row = 0; row < qBlockSize; ++row) {
+          if (has_attn_mask) {
+            // max per row
+            tmp_max = vec::reduce_all<accum_t>(
+                [](Vec& x, Vec& y) { return vec::maximum(x, y); },
+                qk_data + row * kvBlockSize,
+                kvBlockSize);
+          } else {
+            // apply scaling factor and max per row in fusion
+            _mul_reduce_max_fusion_kernel(
+                qk_data + row * kvBlockSize,
+                scaling_factor,
+                kvBlockSize,
+                qk_data + row * kvBlockSize,
+                tmp_max);
+          }
+          tmp_max = qk_max_data[row] > tmp_max ? qk_max_data[row] : tmp_max;
+          // qk <- exp(qk - max) and sum per row
+          tmp_sum = tmp_max;
+          _exp_reduce_sum_fusion_kernel(
+              qk_data + row * kvBlockSize,
+              kvBlockSize,
+              conditional_data_ptr(qk_data, qk_reduced_data) +
+                  row * kvBlockSize,
+              tmp_sum);
+          // exp_tmp <- exp(max[row] - max)
+          exp_tmp = std::exp(qk_max_data[row] - tmp_max);
+          // sum[row] <- sum + exp_tmp * sum[row]
+          qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row];
+          // max[row] <- max
+          qk_max_data[row] = tmp_max;
+          // dst <- dst * exp_tmp
+          if (n > 0) {
+            vec::map<accum_t>(
+                [exp_tmp](Vec x) { return x * Vec(exp_tmp); },
+                dst_data + row * headSize,
+                dst_data + row * headSize,
+                headSize);
+          }
+        }
+        // Calculate Softmax(q @ k.T) @ v
+        ::executorch::cpublas::gemm(
+            ::executorch::cpublas::TransposeType::NoTranspose,
+            ::executorch::cpublas::TransposeType::NoTranspose,
+            headSize,
+            qBlockSize,
+            kvBlockSize,
+            static_cast<accum_t>(1),
+            v_data + i * vStrideB + j * vStrideH + n * vStrideN,
+            vStrideN,
+            conditional_data_ptr(qk_data, qk_reduced_data),
+            kvBlockSize,
+            n == 0 ? static_cast<accum_t>(0) : static_cast<accum_t>(1),
+            dst_data,
+            headSize);
+      }
+      // dst <- dst / sum[row]
+      // reorder MHA output with strides
+      for (int64_t row = 0; row < qBlockSize; ++row) {
+        accum_t sum_reciprocal = 1 / qk_sum_data[row];
+        vec::map<scalar_t>(
+            [sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); },
+            out_data + i * oStrideB + j * oStrideH + m * oStrideM +
+                row * oStrideM,
+            dst_data + row * headSize,
+            headSize);
+      }
+      // Move to the next query
+      util::data_index_step(i, batchSize, j, num_head, k, qSlice);
+    }
+  };
+  torch::executor::parallel_for(
+      0, batchSize * num_head * qSlice, 1, compute_lambda);
 }
 
 bool validate_flash_attention_args(
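For reference, each worker decodes its flat work index back into (batch, head, q-block) coordinates with util::data_index_init and advances them with util::data_index_step. A rough equivalent of that mapping, assuming the usual row-major order over (batchSize, num_head, qSlice) implied by the argument order (the real helpers in op_sdpa.cpp carry the indices incrementally instead of dividing each time):

    // Hypothetical flattened decode, for illustration only.
    int64_t i = z / (num_head * qSlice); // batch index
    int64_t j = (z / qSlice) % num_head; // head index
    int64_t k = z % qSlice;              // q-block index
    int64_t m = k * qSplitSize;          // first query row handled by this block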

examples/models/llama2/custom_ops/targets.bzl

Lines changed: 2 additions & 0 deletions
@@ -99,6 +99,8 @@ def define_common_targets():
             "//executorch/kernels/portable/cpu:scalar_utils",
             "//executorch/kernels/optimized:libblas",
             "//executorch/kernels/optimized:libvec",
+            "//executorch/extension/parallel:thread_parallel",
+            "//executorch/backends/xnnpack/threadpool:threadpool",
         ],
         compiler_flags = ["-Wno-missing-prototypes"],
         visibility = [

examples/models/llama2/main.cpp

Lines changed: 4 additions & 2 deletions
@@ -63,8 +63,10 @@ int32_t main(int32_t argc, char** argv) {
       : static_cast<uint32_t>(cpu_threads);
   ET_LOG(
       Info, "Resetting threadpool with num threads = %d", num_performant_cores);
-  torch::executorch::threadpool::get_threadpool()->_unsafe_reset_threadpool(
-      num_performant_cores);
+  if (num_performant_cores > 0) {
+    torch::executorch::threadpool::get_threadpool()->_unsafe_reset_threadpool(
+        num_performant_cores);
+  }
 #endif
   // create llama runner
   ::torch::executor::Runner runner(model_path, tokenizer_path, temperature);
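Taken together with the op_sdpa.cpp change, the threadpool size set in main.cpp now flows directly into the kernel: the runner resizes the pool only when a positive performant-core count is available, and cpu_flash_attention reads that size back via get_thread_count(). A minimal sketch combining the two calls used in this commit:

    // Resize the pool only when a positive core count is known (main.cpp)...
    if (num_performant_cores > 0) {
      torch::executorch::threadpool::get_threadpool()->_unsafe_reset_threadpool(
          num_performant_cores);
    }
    // ...and size the SDPA work distribution from the same pool (op_sdpa.cpp).
    int64_t num_thread =
        torch::executorch::threadpool::get_threadpool()->get_thread_count();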
