Commit 2df373a
CUDA: fix matrix multiplication algorithm choice (ggml-org#8102)
Parent: 3b099bc

1 file changed: 4 additions, 4 deletions

ggml-cuda.cu

@@ -1924,16 +1924,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // FP32 precision KQV single-batch for batch size 1 without FlashAttention
         ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst);
+    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
+            && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+        // KQ + KQV multi-batch without FlashAttention
+        ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
     } else if (use_mul_mat_vec_q) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
     } else if (use_mul_mat_q) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
-            && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
-        // KQ + KQV multi-batch without FlashAttention
-        ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
     }
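
For context, the relocated branch routes F16 multi-batch multiplications to the batched cuBLAS kernel before the use_dequantize_mul_mat_vec, use_mul_mat_vec_q, and use_mul_mat_q branches are considered. The sketch below restates its guard as a standalone predicate; prefers_batched_cublas is a hypothetical name used only for illustration (not part of ggml), and split and any_gpus_with_slow_fp16 are assumed to be computed earlier in ggml_cuda_mul_mat, as in the surrounding code.

#include "ggml.h"

// Hypothetical helper (illustration only, not part of ggml): the condition under
// which the batched cuBLAS kernel is now chosen ahead of the vector/MMQ kernels.
static bool prefers_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1,
                                   bool split, bool any_gpus_with_slow_fp16) {
    return !split                                                     // src0 is not a multi-GPU split buffer
        && src0->type == GGML_TYPE_F16                                // F16 weights
        && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)  // F16 activations, or all GPUs have fast FP16
        && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)     // plain (non-transposed) layouts
        && src1->ne[2]*src1->ne[3] > 1;                               // more than one matrix in the batch
}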
