@@ -1590,7 +1590,8 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
1590
1590
1591
1591
// number of rows/cols for flash attention shader
1592
1592
static constexpr uint32_t flash_attention_num_small_rows = 32;
1593
- static constexpr uint32_t scalar_flash_attention_num_small_rows = 8;
1593
+ static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
1594
+ static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
1594
1595
1595
1596
static uint32_t get_fa_num_small_rows(bool scalar) {
1596
1597
return scalar ? scalar_flash_attention_num_small_rows : flash_attention_num_small_rows;
@@ -1599,8 +1600,16 @@ static uint32_t get_fa_num_small_rows(bool scalar) {
1599
1600
static std::array<uint32_t, 2> fa_rows_cols(bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
1600
1601
GGML_UNUSED(clamp);
1601
1602
1603
+ if (scalar) {
1604
+ if (small_rows) {
1605
+ return {scalar_flash_attention_num_small_rows, 64};
1606
+ } else {
1607
+ return {scalar_flash_attention_num_large_rows, 32};
1608
+ }
1609
+ }
1610
+
1602
1611
// small rows, large cols
1603
- if (small_rows || scalar ) {
1612
+ if (small_rows) {
1604
1613
return {get_fa_num_small_rows(scalar), 32};
1605
1614
}
1606
1615
@@ -5729,8 +5738,29 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5729
5738
assert(q->type == GGML_TYPE_F32);
5730
5739
assert(k->type == v->type);
5731
5740
5732
- vk_pipeline *pipelines;
5733
5741
bool scalar = !ctx->device->coopmat2;
5742
+
5743
+ uint32_t gqa_ratio = 1;
5744
+ uint32_t qk_ratio = neq2 / nek2;
5745
+ uint32_t workgroups_x = (uint32_t)neq1;
5746
+ uint32_t workgroups_y = (uint32_t)neq2;
5747
+ uint32_t workgroups_z = (uint32_t)neq3;
5748
+
5749
+ // For scalar FA, we can use the "large" size to accommodate gqa.
5750
+ // For coopmat FA, we always use the small size (which is still pretty large for gqa).
5751
+ const uint32_t max_gqa = scalar ? scalar_flash_attention_num_large_rows : get_fa_num_small_rows(false);
5752
+
5753
+ if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
5754
+ qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
5755
+ // grouped query attention - make the N dimension equal to gqa_ratio, reduce
5756
+ // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
5757
+ // and change addressing calculations to index Q's dimension 2.
5758
+ gqa_ratio = qk_ratio;
5759
+ N = gqa_ratio;
5760
+ workgroups_y /= N;
5761
+ }
5762
+
5763
+ vk_pipeline *pipelines;
5734
5764
// XXX TODO other backends may be changing accumulator precision to default to f32 soon
5735
5765
bool f32acc = scalar || dst->op_params[3] == GGML_PREC_F32;
5736
5766
bool small_rows = N <= get_fa_num_small_rows(scalar);
@@ -5776,24 +5806,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5776
5806
vk_pipeline pipeline = pipelines[aligned];
5777
5807
assert(pipeline);
5778
5808
5779
- uint32_t gqa_ratio = 1;
5780
- uint32_t qk_ratio = neq2 / nek2;
5781
- uint32_t workgroups_x = (uint32_t)neq1;
5782
- uint32_t workgroups_y = (uint32_t)neq2;
5783
- uint32_t workgroups_z = (uint32_t)neq3;
5784
-
5785
- const uint32_t max_gqa = get_fa_num_small_rows(scalar);
5786
-
5787
- if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
5788
- qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
5789
- // grouped query attention - make the N dimension equal to gqa_ratio, reduce
5790
- // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
5791
- // and change addressing calculations to index Q's dimension 2.
5792
- gqa_ratio = qk_ratio;
5793
- N = gqa_ratio;
5794
- workgroups_y /= N;
5795
- }
5796
-
5797
5809
uint32_t split_kv = KV;
5798
5810
uint32_t split_k = 1;
5799
5811
0 commit comments