@@ -308,13 +308,72 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
     if (Q->ne[1] == 1) {
         constexpr int cols_per_block = 1;
-        constexpr int parallel_blocks = 4;
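+        // number of blocks needed to cover all Q columns, heads and sequences (without splitting the KV length),
+        // the SM count of the current device, and the number of D-wide tiles along the KV sequence length: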
+        const int total_blocks = (((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]);
+        const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+        const int seqlen_tiles = (K->ne[1] + D - 1) / D;
+
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
-        } else {
+
+            // Determine the number of active blocks per SM.
+            // The parallel_blocks template parameter has no effect on the number of active blocks, so a constant 4 is used for the occupancy query.
+            int numActiveBlocks = 1;
+            CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks, flash_attn_vec_ext_f32<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>, D, 0));
+
+            // We want to keep at least `numActiveBlocks` blocks per SM to improve occupancy.
+            // This kernel operates on `D`-sized tiles of the sequence length, so we need to consider how many `D` tiles can be processed in parallel.
+            // If there are not enough tiles to process, we can reduce the number of blocks.
+            const int parallel_blocks = std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);
+
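+            // parallel_blocks is a template parameter of the kernel, so round it down to the nearest instantiated value (24, 16, 12, 8 or 4):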
+            if (parallel_blocks >= 24)
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 24, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+            else if (parallel_blocks >= 16)
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 16, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+            else if (parallel_blocks >= 12)
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 12, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+            else if (parallel_blocks >= 8)
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 8, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+            else
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+        }
+        else
+        {
             constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
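+            // Same occupancy-based choice of parallel_blocks as in the logit_softcap == 0.0f branch above.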
+            int numActiveBlocks = 1;
+            CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks, flash_attn_vec_ext_f32<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>, D, 0));
+
+            const int parallel_blocks = std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);
+
+            if (parallel_blocks >= 24)
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 24, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+            else if (parallel_blocks >= 16)
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 16, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+            else if (parallel_blocks >= 12)
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 12, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+            else if (parallel_blocks >= 8)
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 8, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+            else
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
         }
         return;
     }