@@ -315,66 +315,68 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
 
+        // cudaOccupancyMaxActiveBlocksPerMultiprocessor is not supported on HIP platform
+        // so, skipping the occupancy check for HIP platform
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
         // Determine the number of active blocks per SM
         // parallel_blocks template parameter has no effect on the number of active blocks, so keeping a constant 4 to determine active blocks
         int numActiveBlocks = 1;
-        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks, flash_attn_vec_ext_f32<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>, D, 0));
+        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks,
+            flash_attn_vec_ext_f32<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>, D, 0));
 
         // we want to keep at least `numActiveBlocks` blocks per SM to improve occupancy.
         // this kernel operates on `D` tile of seq length. We need to consider how many `D` tiles can be processed in parallel.
         // If there are not enough tiles to process, we can reduce the number of blocks
         const int parallel_blocks = std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);
 
-        if (parallel_blocks >= 24)
-        {
+        if (parallel_blocks >= 24) {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 24, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else if (parallel_blocks >= 16)
-        {
+        else if (parallel_blocks >= 16) {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 16, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else if (parallel_blocks >= 12)
-        {
+        else if (parallel_blocks >= 12) {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 12, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else if (parallel_blocks >= 8)
-        {
+        else if (parallel_blocks >= 8) {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 8, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else
+        else
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
         {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
     }
     else
     {
         constexpr bool use_logit_softcap = true;
+
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
         int numActiveBlocks = 1;
-        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks, flash_attn_vec_ext_f32<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>, D, 0));
+        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks,
+            flash_attn_vec_ext_f32<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>, D, 0));
 
         const int parallel_blocks = std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);
 
-        if (parallel_blocks >= 24)
-        {
+        if (parallel_blocks >= 24) {
            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 24, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else if (parallel_blocks >= 16)
-        {
+        else if (parallel_blocks >= 16) {
            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 16, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else if (parallel_blocks >= 12)
-        {
+        else if (parallel_blocks >= 12) {
            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 12, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else if (parallel_blocks >= 8)
-        {
+        else if (parallel_blocks >= 8) {
            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 8, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else
+        else
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
         {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
     }
+
     return;
 }
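For reference, the heuristic in this hunk reduces to: ask the runtime how many blocks of the kernel one SM can keep resident, multiply by the SM count to get the number of blocks the whole GPU can hold, divide by the blocks already being launched, and clamp to the number of sequence tiles available. Below is a minimal standalone sketch of that pattern, not the ggml code itself: the dummy kernel, block size, and the total_blocks/seqlen_tiles values are made-up placeholders, and only the cudaOccupancyMaxActiveBlocksPerMultiprocessor call and the parallel_blocks formula mirror the diff.

// Standalone sketch (placeholder values): derive a sequence-parallelism factor
// from the occupancy the CUDA runtime reports for a kernel.
#include <cuda_runtime.h>
#include <algorithm>
#include <cstdio>

// Hypothetical stand-in for the real flash-attention kernel.
__global__ void dummy_kernel(float * dst) {
    dst[blockIdx.x * blockDim.x + threadIdx.x] = 1.0f;
}

int main() {
    int device = 0;
    cudaGetDevice(&device);

    // nsm: number of streaming multiprocessors, as in the diff.
    int nsm = 0;
    cudaDeviceGetAttribute(&nsm, cudaDevAttrMultiProcessorCount, device);

    const int block_size   = 128; // stands in for the head size D used as the block size above
    const int total_blocks = 32;  // hypothetical: blocks launched before splitting the sequence
    const int seqlen_tiles = 64;  // hypothetical: number of D-sized tiles in the sequence

    // Ask the runtime how many blocks of this kernel one SM can keep resident
    // at this block size with no dynamic shared memory.
    int numActiveBlocks = 1;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks, dummy_kernel, block_size, 0);

    // Same formula as the diff: saturate the GPU (nsm * numActiveBlocks resident
    // blocks) but never split the sequence into more pieces than there are tiles.
    const int parallel_blocks = std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);

    printf("numActiveBlocks = %d, parallel_blocks = %d\n", numActiveBlocks, parallel_blocks);
    return 0;
}

On HIP and MUSA builds the hunk compiles out everything between #if and #endif, so the occupancy query is skipped and the bare block after the #endif always runs, which falls back to the smallest parallel_blocks specialization (4).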