Skip to content

Commit 21c84b5

Browse files
CUDA: fix Volta FlashAttention logic (#11615)
1 parent d92cb67 commit 21c84b5

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
561561
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
562562
break;
563563
// case 256:
564-
// ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
564+
// ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
565565
// break;
566566
default:
567567
GGML_ABORT("fatal error");

ggml/src/ggml-cuda/fattn.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
235235
return;
236236
}
237237

238-
if (!new_mma_available(cc)) {
238+
if (!fp16_mma_available(cc)) {
239239
if (prec == GGML_PREC_DEFAULT) {
240240
if (Q->ne[1] <= 8) {
241241
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
@@ -265,6 +265,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
265265
// The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
266266
if (cc == GGML_CUDA_CC_VOLTA) {
267267
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
268+
return;
268269
}
269270

270271
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);

0 commit comments

Comments
 (0)