@@ -308,16 +308,13 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
 
     if (Q->ne[1] == 1) {
         constexpr int cols_per_block = 1;
-        const int total_blocks = (((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]);
+        const int num_blocks_base = (((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]);
         const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
         const int seqlen_tiles = (K->ne[1] + D - 1) / D;
 
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
 
-            // cudaOccupancyMaxActiveBlocksPerMultiprocessor is not supported on HIP platform
-            // so, skipping the occupancy check for HIP platform
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
             // Determine the number of active blocks per SM
             // parallel_blocks template parameter has no effect on the number of active blocks, so keeping a constant 4 to determine active blocks
             int numActiveBlocks = 1;
@@ -327,7 +324,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
             // we want to keep at least `numActiveBlocks` blocks per SM to improve occupancy.
             // this kernel operates on `D` tile of seq length. We need to consider how many `D` tiles can be processed in parallel.
             // If there are not enough tiles to process, we can reduce the number of blocks
-            const int parallel_blocks = std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);
+            const int parallel_blocks = std::min((nsm * numActiveBlocks) / num_blocks_base, seqlen_tiles);
 
             if (parallel_blocks >= 24) {
                 ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 24, type_K, type_V, use_logit_softcap>(ctx, dst);
@@ -341,22 +338,19 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
             else if (parallel_blocks >= 8) {
                 ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 8, type_K, type_V, use_logit_softcap>(ctx, dst);
             }
-            else
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
-            {
+            else {
                 ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>(ctx, dst);
             }
         }
         else
         {
             constexpr bool use_logit_softcap = true;
 
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
             int numActiveBlocks = 1;
             CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks,
                 flash_attn_vec_ext_f32<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>, D, 0));
 
-            const int parallel_blocks = std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);
+            const int parallel_blocks = std::min((nsm * numActiveBlocks) / num_blocks_base, seqlen_tiles);
 
             if (parallel_blocks >= 24) {
                 ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 24, type_K, type_V, use_logit_softcap>(ctx, dst);
@@ -370,9 +364,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
             else if (parallel_blocks >= 8) {
                 ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 8, type_K, type_V, use_logit_softcap>(ctx, dst);
             }
-            else
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
-            {
+            else {
                 ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>(ctx, dst);
             }
         }
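
For context on what the patch computes: the heuristic sizes `parallel_blocks` from device occupancy. It asks the runtime how many blocks of the kernel can be resident per SM, multiplies by the SM count to get the device-wide block budget, divides by the number of blocks the grid already needs (`num_blocks_base`), and caps the result at `seqlen_tiles`, the number of `D`-wide tiles the KV sequence can actually be split into. Below is a minimal standalone sketch of that calculation; the kernel `dummy_kernel`, its launch parameters, and the example shapes are illustrative placeholders, not part of the patch.

```cpp
#include <algorithm>
#include <cstdio>
#include <cuda_runtime.h>

// Stand-in for flash_attn_vec_ext_f32: only its register/shared-memory
// footprint matters for the occupancy query.
__global__ void dummy_kernel(float * dst) {
    if (dst != nullptr) {
        dst[blockIdx.x] = 0.0f;
    }
}

int main() {
    int device = 0;
    cudaGetDevice(&device);

    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    const int nsm = prop.multiProcessorCount; // number of SMs on the device

    // How many blocks of this kernel fit on one SM at the given block size
    // and dynamic shared memory usage.
    int numActiveBlocks = 1;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &numActiveBlocks, dummy_kernel, /*blockSize=*/128, /*dynamicSMemSize=*/0);

    // Illustrative shapes (placeholders): one block per Q column, per head,
    // per batch entry; KV sequence split into tiles of width D.
    const int num_blocks_base = 32;   // columns * heads * batch
    const int D               = 128;  // tile width along the KV sequence
    const int kv_len          = 4096; // KV sequence length
    const int seqlen_tiles    = (kv_len + D - 1) / D;

    // Split each column's KV work across up to `parallel_blocks` blocks, but
    // never more than keeps every SM busy, and never more than there are tiles.
    const int parallel_blocks = std::min((nsm * numActiveBlocks) / num_blocks_base, seqlen_tiles);

    std::printf("nsm=%d activeBlocks/SM=%d parallel_blocks=%d\n",
                nsm, numActiveBlocks, parallel_blocks);
    return 0;
}
```

Since `parallel_blocks` is a compile-time template parameter of the real kernel, the runtime value is then bucketed onto a few precompiled instantiations (24, 8, 4 in the hunks above) rather than used directly.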