Skip to content

Commit 656fc8e

Browse files
committed
Review suggestions
+ Add defines to vendors/hip.h and vendors/musa.h
1 parent 153bb26 commit 656fc8e

File tree

3 files changed

+7
-13
lines changed

3 files changed

+7
-13
lines changed

ggml/src/ggml-cuda/fattn-vec-f32.cuh

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -308,16 +308,13 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
308308

309309
if (Q->ne[1] == 1) {
310310
constexpr int cols_per_block = 1;
311-
const int total_blocks = (((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]);
311+
const int num_blocks_base = (((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]);
312312
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
313313
const int seqlen_tiles = (K->ne[1] + D - 1) / D;
314314

315315
if (logit_softcap == 0.0f) {
316316
constexpr bool use_logit_softcap = false;
317317

318-
// cudaOccupancyMaxActiveBlocksPerMultiprocessor is not supported on HIP platform
319-
// so, skipping the occupancy check for HIP platform
320-
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
321318
// Determine the number of active blocks per SM
322319
// parallel_blocks template parameter has no effect on the number of active blocks, so keeping a constant 4 to determine active blocks
323320
int numActiveBlocks = 1;
@@ -327,7 +324,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
327324
// we want to keep at least `numActiveBlocks` blocks per SM to improve occupancy.
328325
// this kernel operates on `D` tile of seq length. We need to consider how many `D` tiles can be processed in parallel.
329326
// If there are not enough tiles to process, we can reduce the number of blocks
330-
const int parallel_blocks = std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);
327+
const int parallel_blocks = std::min((nsm * numActiveBlocks) / num_blocks_base, seqlen_tiles);
331328

332329
if (parallel_blocks >= 24) {
333330
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 24, type_K, type_V, use_logit_softcap>(ctx, dst);
@@ -341,22 +338,19 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
341338
else if (parallel_blocks >= 8) {
342339
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 8, type_K, type_V, use_logit_softcap>(ctx, dst);
343340
}
344-
else
345-
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
346-
{
341+
else {
347342
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>(ctx, dst);
348343
}
349344
}
350345
else
351346
{
352347
constexpr bool use_logit_softcap = true;
353348

354-
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
355349
int numActiveBlocks = 1;
356350
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks,
357351
flash_attn_vec_ext_f32<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>, D, 0));
358352

359-
const int parallel_blocks = std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);
353+
const int parallel_blocks = std::min((nsm * numActiveBlocks) / num_blocks_base, seqlen_tiles);
360354

361355
if (parallel_blocks >= 24) {
362356
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 24, type_K, type_V, use_logit_softcap>(ctx, dst);
@@ -370,9 +364,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
370364
else if (parallel_blocks >= 8) {
371365
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 8, type_K, type_V, use_logit_softcap>(ctx, dst);
372366
}
373-
else
374-
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
375-
{
367+
else {
376368
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>(ctx, dst);
377369
}
378370
}

ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@
129129
#define cudaGraph_t hipGraph_t
130130
#define cudaStream_t hipStream_t
131131
#define cudaSuccess hipSuccess
132+
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
132133
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
133134
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
134135
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED

ggml/src/ggml-cuda/vendors/musa.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,5 +133,6 @@
133133
#define cudaKernelNodeParams musaKernelNodeParams
134134
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
135135
#define cudaStreamEndCapture musaStreamEndCapture
136+
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
136137

137138
typedef mt_bfloat16 nv_bfloat16;

0 commit comments

Comments (0)