@@ -321,11 +321,11 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
     // Determine the number of active blocks per SM
     // parallel_blocks template parameter has no effect on the number of active blocks, so keeping a constant 4 to determine active blocks
     int numActiveBlocks = 1;
-    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks,
+    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks,
        flash_attn_vec_ext_f32<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>, D, 0));
 
     // we want to keep at least `numActiveBlocks` blocks per SM to improve occupancy.
-    // this kernel operates on `D` tile of seq length. We need to consider how many `D` tiles can be processed in parallel.
+    // this kernel operates on `D` tile of seq length. We need to consider how many `D` tiles can be processed in parallel.
     // If there are not enough tiles to process, we can reduce the number of blocks
     const int parallel_blocks = std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);
 
@@ -341,7 +341,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
         else if (parallel_blocks >= 8) {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 8, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else
+        else
 #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
         {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>(ctx, dst);
@@ -353,7 +353,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
 
 #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
     int numActiveBlocks = 1;
-    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks,
+    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks,
        flash_attn_vec_ext_f32<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>, D, 0));
 
     const int parallel_blocks = std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);
@@ -370,7 +370,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
         else if (parallel_blocks >= 8) {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 8, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else
+        else
 #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
         {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>(ctx, dst);
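For reference, here is a minimal, self-contained sketch (not part of this PR) of how `cudaOccupancyMaxActiveBlocksPerMultiprocessor` can be combined with the SM count to derive a `parallel_blocks` value, mirroring the heuristic in the hunks above. The kernel `dummy_kernel`, the block size of 128, and the `total_blocks`/`seqlen_tiles` values are hypothetical placeholders standing in for the real kernel and problem shape.

```cpp
// Standalone sketch (not the llama.cpp code): query occupancy for a kernel and
// derive a parallel-blocks count the way the diff above does.
#include <algorithm>
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical stand-in for flash_attn_vec_ext_f32; only its resource usage matters here.
__global__ void dummy_kernel(float * dst) {
    dst[blockIdx.x * blockDim.x + threadIdx.x] = 0.0f;
}

int main() {
    int device = 0;
    cudaGetDevice(&device);

    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    const int nsm = prop.multiProcessorCount; // number of SMs on the device

    // How many blocks of this kernel fit on one SM for a block size of 128
    // threads and no dynamic shared memory (the diff passes D and 0 here).
    int numActiveBlocks = 1;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks, dummy_kernel,
                                                  /*blockSize=*/128, /*dynamicSMemSize=*/0);

    // Hypothetical workload shape: total_blocks blocks already come from the
    // problem dimensions, and seqlen_tiles caps how far the sequence can be split.
    const int total_blocks = 32;
    const int seqlen_tiles = 16;

    // Same heuristic as the diff: keep roughly nsm * numActiveBlocks blocks in
    // flight, but never use more parallel blocks than there are sequence tiles.
    const int parallel_blocks = std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);

    std::printf("SMs: %d, active blocks/SM: %d, parallel_blocks: %d\n",
                nsm, numActiveBlocks, parallel_blocks);
    return 0;
}
```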