 
 #include <array>
 
+#ifdef ET_USE_THREADPOOL
+#include <executorch/backends/xnnpack/threadpool/threadpool.h>
+#include <executorch/extension/parallel/thread_parallel.h>
+#endif
+
 namespace torch {
 namespace executor {
+
 namespace native {
 
 namespace util {
 
 constexpr size_t kKVDim = 4;
 
 template <typename T>
-inline void _store(T* dst, executorch::vec::Vectorized<T> src) {
+inline void _store(T* dst, ::executorch::vec::Vectorized<T> src) {
   src.store(dst);
 }
 
@@ -38,19 +44,6 @@ inline void _store(::Half* dst, at::vec::Vectorized<float> src) {
 }
 */
 
-template <class F>
-inline void parallel_for(
-    const int64_t begin,
-    const int64_t end,
-    const int64_t grain_size,
-    const F& f) {
-  for (int64_t i = begin; i < end; i += grain_size) {
-    int64_t task_begin = i;
-    int64_t task_end = std::min(task_begin + grain_size, end);
-    f(task_begin, task_end);
-  }
-}
-
 template <typename T>
 inline T data_index_init(T offset) {
   return offset;
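
The helper deleted above walked [begin, end) in grain_size chunks on the calling thread; the threadpool-backed torch::executor::parallel_for that replaces it (see the final hunk) takes the same (begin, end, grain_size, function) arguments, so only the dispatch changes. A minimal standalone sketch of that chunking contract, using only the semantics visible in the removed code (chunked_for and the example bounds are illustrative, not names from this file):

#include <algorithm>
#include <cstdint>
#include <iostream>

// Split [begin, end) into grain_size-sized tasks and hand each sub-range to f,
// which is exactly what the removed serial helper did on one thread.
template <class F>
void chunked_for(int64_t begin, int64_t end, int64_t grain_size, const F& f) {
  for (int64_t i = begin; i < end; i += grain_size) {
    f(i, std::min(i + grain_size, end));
  }
}

int main() {
  // e.g. batchSize * num_head * qSlice = 7 work items with grain size 2
  chunked_for(0, 7, 2, [](int64_t task_begin, int64_t task_end) {
    std::cout << "task [" << task_begin << ", " << task_end << ")\n";
  });
  return 0;
}

A pool-backed implementation hands those same sub-ranges to worker threads instead of running them inline, which is why the call site in cpu_flash_attention only needs to swap which parallel_for it calls.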
@@ -83,7 +76,7 @@ inline double calculate_scale(const Tensor& query, optional<double> scale) {
 }
 
 } // namespace util
-namespace vec = executorch::vec;
+namespace vec = ::executorch::vec;
 using Tensor = exec_aten::Tensor;
 
 namespace {
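
This hunk, like the _store and cpublas changes, adds a leading :: so the alias resolves against the top-level executorch namespace rather than whatever unqualified lookup finds inside torch::executor. A small self-contained illustration of that lookup rule; the nested torch::executor::executorch namespace below is invented purely to show the shadowing the qualifier guards against:

#include <iostream>

namespace executorch {
void who() {
  std::cout << "::executorch\n";
}
} // namespace executorch

namespace torch {
namespace executor {
namespace executorch {
void who() {
  std::cout << "torch::executor::executorch\n";
}
} // namespace executorch
namespace native {
void call() {
  executorch::who(); // unqualified lookup finds the nested namespace first
  ::executorch::who(); // the leading :: always names the top-level one
}
} // namespace native
} // namespace executor
} // namespace torch

int main() {
  torch::executor::native::call();
  return 0;
}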
@@ -310,8 +303,12 @@ void cpu_flash_attention(
   int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
   int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
   int64_t qSlice = (qSize - 1) / qSplitSize + 1;
-  // int64_t num_thread = at::get_num_threads();
-  int64_t num_thread = 1; // at::get_num_threads();
+#ifdef ET_USE_THREADPOOL
+  int64_t num_thread =
+      torch::executorch::threadpool::get_threadpool()->get_thread_count();
+#else
+  int64_t num_thread = 1;
+#endif
 
   // const auto dtype = query.scalar_type();
   // Following will be revisited in the future
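
num_thread matters beyond scheduling: each worker in the next hunk indexes a private slice of one shared scratch allocation as buf_data + ompIdx * size_per_thread, so the allocation must be sized for the thread count chosen here. A standalone sketch of that layout; num_thread, size_per_thread, and the float element type are stand-in example values, not taken from this file:

#include <cstdint>
#include <vector>

int main() {
  const int64_t num_thread = 4; // e.g. get_threadpool()->get_thread_count()
  const int64_t size_per_thread = 1024; // per-worker scratch elements
  // One allocation sized for every worker; thread i owns the half-open slice
  // [i * size_per_thread, (i + 1) * size_per_thread).
  std::vector<float> buf(num_thread * size_per_thread);
  const int64_t ompIdx = 2; // e.g. torch::executor::get_thread_num()
  float* buf_ptr = buf.data() + ompIdx * size_per_thread;
  buf_ptr[0] = 0.0f; // each worker writes only its own slice, so no data races
  return 0;
}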
@@ -346,149 +343,146 @@ void cpu_flash_attention(
   scalar_t* buf_reduced_data =
       is_reduced_type ? reinterpret_cast<scalar_t*>(buf_reduced) : nullptr;
 
-  util::parallel_for(
-      0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) {
-        int64_t i = 0, j = 0, k = 0;
-        util::data_index_init(begin, i, batchSize, j, num_head, k, qSlice);
-        int ompIdx = 0; // at::get_thread_num();
-        accum_t* buf_ptr = buf_data + ompIdx * size_per_thread;
-        accum_t* qk_data = buf_ptr;
-        accum_t* qk_max_data = qk_data + qSplitSize * kvSplitSize;
-        accum_t* qk_sum_data = qk_max_data + qSplitSize;
-        accum_t* dst_data = qk_sum_data + qSplitSize;
-        scalar_t* qk_reduced_data = is_reduced_type
-            ? buf_reduced_data + ompIdx * qSplitSize * kvSplitSize
-            : nullptr;
-
-        for (int64_t z = begin; z < end; z++) {
-          int64_t m = k * qSplitSize;
-          int64_t qBlockSize = std::min(qSplitSize, qSize - m);
-          // Initialize max and sum
-          fill_stub(
-              qk_max_data,
-              -std::numeric_limits<accum_t>::infinity(),
-              qBlockSize);
-          int64_t num_keys =
-              is_causal ? std::min(m + qBlockSize, kvSize) : kvSize;
-          for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
-            int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
-            // Calculate scale * q @ k.T
-            fill_stub(
-                qk_data, static_cast<accum_t>(0), qSplitSize * kvSplitSize);
-            executorch::cpublas::gemm(
-                executorch::cpublas::TransposeType::Transpose,
-                executorch::cpublas::TransposeType::NoTranspose,
-                kvBlockSize,
-                qBlockSize,
-                headSize,
-                static_cast<accum_t>(1),
-                k_data + i * kStrideB + j * kStrideH + n * kStrideN,
-                kStrideN,
-                q_data + i * qStrideB + j * qStrideH + m * qStrideM,
-                qStrideM,
-                static_cast<accum_t>(0),
-                qk_data,
-                kvBlockSize);
-            // Apply causal mask, fill unused with -inf
-            if (is_causal && num_keys - n <= kvSplitSize) {
-              for (int32_t row = 0; row < qBlockSize; ++row) {
-                int64_t last_col = m + row - n;
-                accum_t* row_ptr = qk_data + row * kvBlockSize;
-                fill_stub(
-                    row_ptr + last_col + 1,
-                    -std::numeric_limits<accum_t>::infinity(),
-                    kvBlockSize - last_col - 1);
-              }
-            }
-            // Update attention weights with attention mask
-            // And apply scaling factor
-            // qk <- qk * scaling + attn_mask
-            if (has_attn_mask) {
-              for (int64_t row = 0; row < qBlockSize; ++row) {
-                vec::map2<accum_t>(
-                    [scaling_factor](Vec x, Vec y) {
-                      return x * Vec(scaling_factor) + y;
-                    },
-                    qk_data + row * kvBlockSize,
-                    qk_data + row * kvBlockSize,
-                    mask_data + i * mStrideB + j * mStrideH +
-                        (m + row) * mStrideM + n,
-                    kvBlockSize);
-              }
-            }
-            // Update coefficients with Softmax
-            accum_t tmp_max = 0, tmp_sum = 0, exp_tmp = 0;
-            for (int64_t row = 0; row < qBlockSize; ++row) {
-              if (has_attn_mask) {
-                // max per row
-                tmp_max = vec::reduce_all<accum_t>(
-                    [](Vec& x, Vec& y) { return vec::maximum(x, y); },
-                    qk_data + row * kvBlockSize,
-                    kvBlockSize);
-              } else {
-                // apply scaling factor and max per row in fusion
-                _mul_reduce_max_fusion_kernel(
-                    qk_data + row * kvBlockSize,
-                    scaling_factor,
-                    kvBlockSize,
-                    qk_data + row * kvBlockSize,
-                    tmp_max);
-              }
-              tmp_max = qk_max_data[row] > tmp_max ? qk_max_data[row] : tmp_max;
-              // qk <- exp(qk - max) and sum per row
-              tmp_sum = tmp_max;
-              _exp_reduce_sum_fusion_kernel(
-                  qk_data + row * kvBlockSize,
-                  kvBlockSize,
-                  conditional_data_ptr(qk_data, qk_reduced_data) +
-                      row * kvBlockSize,
-                  tmp_sum);
-              // exp_tmp <- exp(max[row] - max)
-              exp_tmp = std::exp(qk_max_data[row] - tmp_max);
-              // sum[row] <- sum + exp_tmp * sum[row]
-              qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row];
-              // max[row] <- max
-              qk_max_data[row] = tmp_max;
-              // dst <- dst * exp_tmp
-              if (n > 0) {
-                vec::map<accum_t>(
-                    [exp_tmp](Vec x) { return x * Vec(exp_tmp); },
-                    dst_data + row * headSize,
-                    dst_data + row * headSize,
-                    headSize);
-              }
-            }
-            // Calculate Softmax(q @ k.T) @ v
-            executorch::cpublas::gemm(
-                executorch::cpublas::TransposeType::NoTranspose,
-                executorch::cpublas::TransposeType::NoTranspose,
-                headSize,
-                qBlockSize,
-                kvBlockSize,
-                static_cast<accum_t>(1),
-                v_data + i * vStrideB + j * vStrideH + n * vStrideN,
-                vStrideN,
-                conditional_data_ptr(qk_data, qk_reduced_data),
-                kvBlockSize,
-                n == 0 ? static_cast<accum_t>(0) : static_cast<accum_t>(1),
-                dst_data,
-                headSize);
-          }
-          // dst <- dst / sum[row]
-          // reorder MHA output with strides
-          for (int64_t row = 0; row < qBlockSize; ++row) {
-            accum_t sum_reciprocal = 1 / qk_sum_data[row];
-            vec::map<scalar_t>(
-                [sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); },
-                out_data + i * oStrideB + j * oStrideH + m * oStrideM +
-                    row * oStrideM,
-                dst_data + row * headSize,
-                headSize);
-          }
-          // Move to the next query
-          util::data_index_step(i, batchSize, j, num_head, k, qSlice);
-        }
-      });
+  auto compute_lambda = [&](int64_t begin, int64_t end) {
+    int64_t i = 0, j = 0, k = 0;
+    util::data_index_init(begin, i, batchSize, j, num_head, k, qSlice);
+    int ompIdx = torch::executor::get_thread_num();
+    accum_t* buf_ptr = buf_data + ompIdx * size_per_thread;
+    accum_t* qk_data = buf_ptr;
+    accum_t* qk_max_data = qk_data + qSplitSize * kvSplitSize;
+    accum_t* qk_sum_data = qk_max_data + qSplitSize;
+    accum_t* dst_data = qk_sum_data + qSplitSize;
+    scalar_t* qk_reduced_data = is_reduced_type
+        ? buf_reduced_data + ompIdx * qSplitSize * kvSplitSize
+        : nullptr;
+
+    for (int64_t z = begin; z < end; z++) {
+      int64_t m = k * qSplitSize;
+      int64_t qBlockSize = std::min(qSplitSize, qSize - m);
+      // Initialize max and sum
+      fill_stub(
+          qk_max_data, -std::numeric_limits<accum_t>::infinity(), qBlockSize);
+      int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize;
+      for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
+        int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
+        // Calculate scale * q @ k.T
+        fill_stub(qk_data, static_cast<accum_t>(0), qSplitSize * kvSplitSize);
+        ::executorch::cpublas::gemm(
+            ::executorch::cpublas::TransposeType::Transpose,
+            ::executorch::cpublas::TransposeType::NoTranspose,
+            kvBlockSize,
+            qBlockSize,
+            headSize,
+            static_cast<accum_t>(1),
+            k_data + i * kStrideB + j * kStrideH + n * kStrideN,
+            kStrideN,
+            q_data + i * qStrideB + j * qStrideH + m * qStrideM,
+            qStrideM,
+            static_cast<accum_t>(0),
+            qk_data,
+            kvBlockSize);
+        // Apply causal mask, fill unused with -inf
+        if (is_causal && num_keys - n <= kvSplitSize) {
+          for (int32_t row = 0; row < qBlockSize; ++row) {
+            int64_t last_col = m + row - n;
+            accum_t* row_ptr = qk_data + row * kvBlockSize;
+            fill_stub(
+                row_ptr + last_col + 1,
+                -std::numeric_limits<accum_t>::infinity(),
+                kvBlockSize - last_col - 1);
+          }
+        }
+        // Update attention weights with attention mask
+        // And apply scaling factor
+        // qk <- qk * scaling + attn_mask
+        if (has_attn_mask) {
+          for (int64_t row = 0; row < qBlockSize; ++row) {
+            vec::map2<accum_t>(
+                [scaling_factor](Vec x, Vec y) {
+                  return x * Vec(scaling_factor) + y;
+                },
+                qk_data + row * kvBlockSize,
+                qk_data + row * kvBlockSize,
+                mask_data + i * mStrideB + j * mStrideH + (m + row) * mStrideM +
+                    n,
+                kvBlockSize);
+          }
+        }
+        // Update coefficients with Softmax
+        accum_t tmp_max = 0, tmp_sum = 0, exp_tmp = 0;
+        for (int64_t row = 0; row < qBlockSize; ++row) {
+          if (has_attn_mask) {
+            // max per row
+            tmp_max = vec::reduce_all<accum_t>(
+                [](Vec& x, Vec& y) { return vec::maximum(x, y); },
+                qk_data + row * kvBlockSize,
+                kvBlockSize);
+          } else {
+            // apply scaling factor and max per row in fusion
+            _mul_reduce_max_fusion_kernel(
+                qk_data + row * kvBlockSize,
+                scaling_factor,
+                kvBlockSize,
+                qk_data + row * kvBlockSize,
+                tmp_max);
+          }
+          tmp_max = qk_max_data[row] > tmp_max ? qk_max_data[row] : tmp_max;
+          // qk <- exp(qk - max) and sum per row
+          tmp_sum = tmp_max;
+          _exp_reduce_sum_fusion_kernel(
+              qk_data + row * kvBlockSize,
+              kvBlockSize,
+              conditional_data_ptr(qk_data, qk_reduced_data) +
+                  row * kvBlockSize,
+              tmp_sum);
+          // exp_tmp <- exp(max[row] - max)
+          exp_tmp = std::exp(qk_max_data[row] - tmp_max);
+          // sum[row] <- sum + exp_tmp * sum[row]
+          qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row];
+          // max[row] <- max
+          qk_max_data[row] = tmp_max;
+          // dst <- dst * exp_tmp
+          if (n > 0) {
+            vec::map<accum_t>(
+                [exp_tmp](Vec x) { return x * Vec(exp_tmp); },
+                dst_data + row * headSize,
+                dst_data + row * headSize,
+                headSize);
+          }
+        }
+        // Calculate Softmax(q @ k.T) @ v
+        ::executorch::cpublas::gemm(
+            ::executorch::cpublas::TransposeType::NoTranspose,
+            ::executorch::cpublas::TransposeType::NoTranspose,
+            headSize,
+            qBlockSize,
+            kvBlockSize,
+            static_cast<accum_t>(1),
+            v_data + i * vStrideB + j * vStrideH + n * vStrideN,
+            vStrideN,
+            conditional_data_ptr(qk_data, qk_reduced_data),
+            kvBlockSize,
+            n == 0 ? static_cast<accum_t>(0) : static_cast<accum_t>(1),
+            dst_data,
+            headSize);
+      }
+      // dst <- dst / sum[row]
+      // reorder MHA output with strides
+      for (int64_t row = 0; row < qBlockSize; ++row) {
+        accum_t sum_reciprocal = 1 / qk_sum_data[row];
+        vec::map<scalar_t>(
+            [sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); },
+            out_data + i * oStrideB + j * oStrideH + m * oStrideM +
+                row * oStrideM,
+            dst_data + row * headSize,
+            headSize);
+      }
+      // Move to the next query
+      util::data_index_step(i, batchSize, j, num_head, k, qSlice);
+    }
+  };
+  torch::executor::parallel_for(
+      0, batchSize * num_head * qSlice, 1, compute_lambda);
 }
 
 bool validate_flash_attention_args(
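
Taken together, the new code flattens the (batch, head, q-slice) loop nest into a single index range so torch::executor::parallel_for can split it across workers, and each worker recovers (i, j, k) from its flat offset with util::data_index_init and advances them with util::data_index_step. A standalone sketch of that row-major decomposition, assuming those helpers follow the usual data_index semantics (the loop bounds below are made-up example values):

#include <cstdint>
#include <iostream>

int main() {
  const int64_t batchSize = 2, num_head = 3, qSlice = 4;
  const int64_t begin = 5, end = 9; // one worker's [begin, end) chunk

  // data_index_init: decompose the flat offset into (i, j, k), row-major.
  int64_t i = begin / (num_head * qSlice);
  int64_t j = (begin / qSlice) % num_head;
  int64_t k = begin % qSlice;

  for (int64_t z = begin; z < end; z++) {
    std::cout << "z=" << z << " -> batch " << i << ", head " << j
              << ", q-slice " << k << "\n";
    // data_index_step: advance k, carrying into j and then i.
    if (++k == qSlice) {
      k = 0;
      if (++j == num_head) {
        j = 0;
        ++i;
      }
    }
  }
  return 0;
}

Because the parallel_for call uses a grain size of 1, every (batch, head, q-slice) block is an independently schedulable task, which is what makes the per-thread scratch slices described earlier sufficient.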