
Commit c6f1177

JohannesGaessler authored and orca-zhang committed
CUDA: fix Volta FlashAttention logic (ggml-org#11615)
1 parent 155035b commit c6f1177

File tree

2 files changed: +3 -2 lines changed


ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 1 addition & 1 deletion
@@ -561,7 +561,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
             ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
             break;
         // case 256:
-        //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+        //     ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
         //     break;
         default:
             GGML_ABORT("fatal error");
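
For context, this WMMA path selects a kernel specialization by switching on the head size, and the change above makes the commented-out 256 case reference the <256, ...> specialization instead of the <128, ...> one, so it would dispatch correctly if ever re-enabled. Below is a minimal sketch of that dispatch pattern, not the ggml code: launch_case<D>() is a hypothetical stand-in for ggml_cuda_flash_attn_ext_wmma_f16_case<D, cols_per_block, float>.

// Hypothetical sketch of head-size dispatch to templated kernel wrappers.
#include <cstdio>
#include <cstdlib>

template <int head_size>
static void launch_case() {
    // A real implementation would launch the kernel compiled for this head size.
    std::printf("launching specialization for head size %d\n", head_size);
}

static void dispatch(int head_size) {
    switch (head_size) {
        case  64: launch_case< 64>(); break;
        case  80: launch_case< 80>(); break;
        case 128: launch_case<128>(); break;
        // case 256: launch_case<256>(); break; // disabled, as in the patch above
        default:
            std::fprintf(stderr, "unsupported head size %d\n", head_size);
            std::abort();
    }
}

int main() {
    dispatch(128); // prints "launching specialization for head size 128"
    return 0;
}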

ggml/src/ggml-cuda/fattn.cu

Lines changed: 2 additions & 1 deletion
@@ -347,7 +347,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         return;
     }
 
-    if (!new_mma_available(cc)) {
+    if (!fp16_mma_available(cc)) {
         if (prec == GGML_PREC_DEFAULT) {
             if (Q->ne[1] <= 8) {
                 ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
@@ -377,6 +377,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
     if (cc == GGML_CUDA_CC_VOLTA) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+        return;
     }
 
     ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
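
The second hunk adds the early return that was missing from the Volta branch: without it, the function launched the WMMA kernel and then fell through to also launch the MMA kernel, which, per the comment, needs Turing or newer. The first hunk appears to widen the fallback check from new_mma_available (Turing and newer) to fp16_mma_available (Volta and newer), so Volta is no longer diverted into the non-MMA fallback and reaches the WMMA branch. A minimal sketch of the corrected control flow, with illustrative constants and placeholder helpers rather than the ggml API:

// Hypothetical sketch of the architecture dispatch fixed by this commit.
#include <cstdio>

constexpr int CC_VOLTA  = 700; // illustrative stand-in for GGML_CUDA_CC_VOLTA (sm_70)
constexpr int CC_TURING = 750; // illustrative: sm_75

static void run_wmma_f16() { std::printf("WMMA FlashAttention (Volta path)\n"); }
static void run_mma_f16()  { std::printf("MMA FlashAttention (Turing+ path)\n"); }

static void flash_attn_dispatch(int cc) {
    // The MMA implementation needs Turing or newer; use the old WMMA code for Volta.
    if (cc == CC_VOLTA) {
        run_wmma_f16();
        return; // the fix: stop here so the MMA kernel is not launched as well
    }
    run_mma_f16();
}

int main() {
    flash_attn_dispatch(CC_VOLTA);  // runs only the WMMA path
    flash_attn_dispatch(CC_TURING); // runs the MMA path
    return 0;
}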
