@@ -342,6 +342,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_split_k_reduce;
 
     std::unordered_map<std::string, vk_pipeline_ref> pipelines;
     std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
@@ -493,6 +494,8 @@ struct vk_flash_attn_push_constants {
     float m1;
 
     uint32_t gqa_ratio;
+    uint32_t split_kv;
+    uint32_t k_num;
 };
 
 struct vk_op_push_constants {
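Because the two new fields are appended after gqa_ratio, the byte offsets of every existing push-constant field stay the same; only the total size grows by 8 bytes, so host struct and shader push-constant block stay in sync as long as both add the fields at the tail. A minimal host-side sketch of that invariant (fa_pc_tail below is a trimmed, hypothetical stand-in for the tail of vk_flash_attn_push_constants, not the real definition):

```cpp
#include <cstdint>
#include <cstddef>

// Trimmed stand-in for the tail of vk_flash_attn_push_constants (illustrative only).
struct fa_pc_tail {
    float    m1;
    uint32_t gqa_ratio;
    uint32_t split_kv; // appended by this change
    uint32_t k_num;    // appended by this change
};

// Appending at the end leaves earlier offsets untouched and only grows the size.
static_assert(offsetof(fa_pc_tail, gqa_ratio) == 4,  "existing field offset changed");
static_assert(offsetof(fa_pc_tail, split_kv)  == 8,  "new fields must come after gqa_ratio");
static_assert(sizeof(fa_pc_tail) == 16, "two new uint32_t fields add 8 bytes");

int main() {}
```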
@@ -1465,7 +1468,7 @@ static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_typ
 
     // small rows, large cols
     if (small_rows) {
-        return {flash_attention_num_small_rows, 128};
+        return {flash_attention_num_small_rows, 64};
     }
     // small cols to reduce register count
     if (ggml_is_quantized(type) || D == 256) {
@@ -2269,6 +2272,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true);
 
     for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
         if (device->subgroup_add && device->subgroup_require_full_support) {
@@ -5309,9 +5313,38 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         workgroups_y /= N;
     }
 
+    uint32_t split_kv = KV;
+    uint32_t split_k = 1;
+
+    if (gqa_ratio > 1 && ctx->device->shader_core_count > 0) {
+        GGML_ASSERT(workgroups_x == 1);
+        // Try to run two workgroups per SM.
+        split_k = ctx->device->shader_core_count * 2 / workgroups_y;
+        if (split_k > 1) {
+            // Try to evenly split KV into split_k chunks, but it needs to be a multiple
+            // of "align", so recompute split_k based on that.
+            split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
+            split_k = CEIL_DIV(KV, split_kv);
+            workgroups_x = split_k;
+        }
+    }
+
+    // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
+    // and the per-row m and L values (ne1 rows).
+    const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
+    if (split_k_size > ctx->device->max_memory_allocation_size) {
+        GGML_ABORT("Requested preallocation size is too large");
+    }
+    if (ctx->prealloc_size_split_k < split_k_size) {
+        ctx->prealloc_size_split_k = split_k_size;
+    }
+
     if (dryrun) {
         // Request descriptor sets
         ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+        if (split_k > 1) {
+            ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
+        }
         return;
     }
 
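The sizing logic in this hunk first guesses split_k from occupancy (two workgroups per shader core), then snaps the per-split KV length up to the pipeline's alignment and recomputes split_k from that, so the final count can come out smaller than the guess. A standalone sketch of the arithmetic, with local stand-ins for ggml's ROUNDUP_POW2/CEIL_DIV macros and a made-up core count:

```cpp
#include <cstdint>
#include <cstdio>

// Local stand-ins for ggml's macros; "align" is assumed to be a power of two.
static uint32_t roundup_pow2(uint32_t x, uint32_t align) { return (x + align - 1) & ~(align - 1); }
static uint32_t ceil_div(uint32_t x, uint32_t y) { return (x + y - 1) / y; }

int main() {
    const uint32_t KV = 4096, workgroups_y = 32, align = 32;
    const uint32_t shader_core_count = 80; // hypothetical GPU

    uint32_t split_kv = KV;
    uint32_t split_k  = shader_core_count * 2 / workgroups_y; // aim for 2 workgroups per SM -> 5
    if (split_k > 1) {
        split_kv = roundup_pow2(KV / split_k, align); // 4096 / 5 = 819, rounded up to 832
        split_k  = ceil_div(KV, split_kv);            // ceil(4096 / 832) = 5 splits cover all of KV
    }
    printf("split_kv=%u split_k=%u\n", split_kv, split_k); // split_kv=832 split_k=5
}
```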
@@ -5332,8 +5365,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     const float m0 = powf(2.0f, -(max_bias) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    ggml_vk_sync_buffers(subctx);
-
     vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr;
     size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0;
 
@@ -5398,16 +5429,45 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                        v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
                        nbm1,
                        scale, max_bias, logit_softcap,
-                       mask != nullptr, n_head_log2, m0, m1, gqa_ratio };
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
-        {
-            vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
-            vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
-            vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
-            vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
-            vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
-        },
-        sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
+                       mask != nullptr, n_head_log2, m0, m1,
+                       gqa_ratio, split_kv, split_k };
+
+    ggml_vk_sync_buffers(subctx);
+
+    if (split_k > 1) {
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+            {
+                vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
+            },
+            // We only use split_k when group query attention is enabled, which means
+            // there's no more than one tile of rows (i.e. workgroups_x would have been
+            // one). We reuse workgroups_x to mean the number of splits, so we need to
+            // cancel out the divide by wg_denoms[0].
+            sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
+
+        ggml_vk_sync_buffers(subctx);
+        const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
+        ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
+            {
+                vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
+                vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
+            },
+            pc2.size() * uint32_t{sizeof(uint32_t)}, pc2.data(), { (uint32_t)ne1, 1, 1 });
+    } else {
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+            {
+                vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
+            },
+            sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
+    }
 }
 
 static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
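The fa_split_k_reduce dispatch above runs one workgroup per output row over the temporaries laid out earlier (per split: a D x ne1 O accumulator plus ne1 m and L values). The reduce shader itself is not part of this diff, so the following is only a hedged host-side sketch of the log-sum-exp merge that layout implies, assuming each split k stores an un-normalized accumulator O_k = sum_j exp(s_j - m_k) * V_j, the row max m_k, and the exp-sum L_k = sum_j exp(s_j - m_k):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Merge split-k flash-attention partials for one output row (illustrative sketch).
std::vector<float> merge_row(const std::vector<std::vector<float>>& O, // O[k][d]: un-normalized accumulator of split k
                             const std::vector<float>& m,              // m[k]: row max of the scores in split k
                             const std::vector<float>& L) {            // L[k]: exp-sum of split k relative to m[k]
    const size_t k_num = O.size(), D = O[0].size();

    float m_max = -INFINITY; // global row max across splits
    for (size_t k = 0; k < k_num; ++k) m_max = std::max(m_max, m[k]);

    float L_total = 0.0f; // combined softmax denominator, rescaled to m_max
    for (size_t k = 0; k < k_num; ++k) L_total += std::exp(m[k] - m_max) * L[k];

    std::vector<float> out(D, 0.0f);
    for (size_t d = 0; d < D; ++d) {
        for (size_t k = 0; k < k_num; ++k) {
            out[d] += std::exp(m[k] - m_max) * O[k][d]; // rescale each partial to the global max
        }
        out[d] /= L_total; // normalize once at the end
    }
    return out;
}

int main() {
    // Two splits of one row with D = 2; values are arbitrary.
    std::vector<std::vector<float>> O = {{1.0f, 2.0f}, {0.5f, 0.25f}};
    std::vector<float> m = {0.3f, 1.1f}, L = {1.4f, 0.9f};
    for (float v : merge_row(O, m, L)) printf("%f\n", v);
}
```

Rescaling every partial to the single global maximum before summing keeps the merge numerically stable regardless of how KV was chunked, which is why only m and L need to be stored alongside each split's O.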