
Commit 177f5af

Fix compilation errors on HIP and MUSA
1 parent: 76881ac

File tree

2 files changed (+22, -24 lines)


ggml/src/ggml-cuda/fattn-vec-f32.cuh

Lines changed: 22 additions & 20 deletions
@@ -315,66 +315,68 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;

+        // cudaOccupancyMaxActiveBlocksPerMultiprocessor is not supported on HIP platform
+        // so, skipping the occupancy check for HIP platform
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
         // Determine the number of active blocks per SM
         // parallel_blocks template parameter has no effect on the number of active blocks, so keeping a constant 4 to determine active blocks
         int numActiveBlocks = 1;
-        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks, flash_attn_vec_ext_f32<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>, D, 0));
+        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks,
+            flash_attn_vec_ext_f32<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>, D, 0));

         // we want to keep at least `numActiveBlocks` blocks per SM to improve occupancy.
         // this kernel operates on `D` tile of seq length. We need to consider how many `D` tiles can be processed in parallel.
         // If there are not enough tiles to process, we can reduce the number of blocks
         const int parallel_blocks = std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);

-        if(parallel_blocks >= 24)
-        {
+        if(parallel_blocks >= 24) {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 24, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else if(parallel_blocks >= 16)
-        {
+        else if(parallel_blocks >= 16) {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 16, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else if(parallel_blocks >= 12)
-        {
+        else if(parallel_blocks >= 12) {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 12, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else if(parallel_blocks >= 8)
-        {
+        else if(parallel_blocks >= 8) {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 8, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else
+        else
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
         {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
     }
     else
     {
         constexpr bool use_logit_softcap = true;
+
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
         int numActiveBlocks = 1;
-        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks, flash_attn_vec_ext_f32<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>, D, 0));
+        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks,
+            flash_attn_vec_ext_f32<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>, D, 0));

         const int parallel_blocks = std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);

-        if(parallel_blocks >= 24)
-        {
+        if(parallel_blocks >= 24) {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 24, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else if(parallel_blocks >= 16)
-        {
+        else if(parallel_blocks >= 16) {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 16, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else if(parallel_blocks >= 12)
-        {
+        else if(parallel_blocks >= 12) {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 12, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else if(parallel_blocks >= 8)
-        {
+        else if(parallel_blocks >= 8) {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 8, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else
+        else
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
         {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
     }
+
     return;
 }

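Note (not part of the commit): below is a minimal, hypothetical sketch of the pattern the hunk above applies — run the occupancy query only where cudaOccupancyMaxActiveBlocksPerMultiprocessor is available, and let HIP/MUSA builds fall back to the fixed parallel_blocks value of 4. The helper name pick_parallel_blocks and its parameter list are invented for illustration; CUDA_CHECK and the GGML_USE_* macros are assumed to come from the surrounding ggml-cuda headers.

// Hypothetical helper (illustration only) showing the guarded occupancy check:
// on CUDA it asks how many blocks of `kernel` fit per SM; on HIP/MUSA it keeps
// the fixed fallback of 4 that the diff's unguarded else-branch uses.
#include <algorithm>  // std::min

template <typename Kernel>
static int pick_parallel_blocks(Kernel kernel, int block_size, size_t dyn_smem,
                                int nsm, int total_blocks, int seqlen_tiles) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
    int numActiveBlocks = 1;
    // Same call shape as in the diff: (&out, kernel, blockSize, dynamicSMemSize).
    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &numActiveBlocks, kernel, block_size, dyn_smem));
    // Spread the (SM count x resident blocks) budget over the launched blocks,
    // but never request more parallel blocks than there are sequence tiles.
    return std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);
#else
    // The occupancy API is unavailable on these platforms, so keep the default.
    (void) kernel; (void) block_size; (void) dyn_smem;
    (void) nsm;    (void) total_blocks; (void) seqlen_tiles;
    return 4;
#endif
}

The returned value would then pick among the parallel_blocks specialisations (24, 16, 12, 8, 4) exactly as the if/else chain in the hunk does.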
ggml/src/ggml-cuda/fattn.cu

Lines changed: 0 additions & 4 deletions
@@ -244,9 +244,6 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-    const ggml_tensor * mask = dst->src[3];

     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
@@ -296,7 +293,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         return;
     }

-    const int gqa_ratio = Q->ne[2] / K->ne[2];
     if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
         if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
