Commit 73075d5

jeffbolznv authored and infil00p committed
vulkan: use scalar FA rather than coopmat2 when N==1 (ggml-org#13554)
1 parent 68b2046 commit 73075d5

1 file changed: +7 -0 lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 7 additions & 0 deletions
@@ -5874,10 +5874,17 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     vk_pipeline *pipelines;
     bool small_rows = N <= get_fa_num_small_rows(path);
 
+    // coopmat1 does not actually support "small rows" (it needs 16 rows).
+    // So use scalar instead.
     if (small_rows && path == FA_COOPMAT1) {
         path = FA_SCALAR;
     }
 
+    // scalar is faster than coopmat2 when N==1
+    if (N == 1 && path == FA_COOPMAT2) {
+        path = FA_SCALAR;
+    }
+
     bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
 
     switch (path) {
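
The fallback rules in this hunk can be read in isolation as a small pure function. The following is a minimal sketch, not the backend's actual code: the `FaCodePath` enum, `select_fa_path`, and `use_f32_accumulator` are hypothetical names introduced for illustration, and the small-rows threshold is passed in as a parameter because the values returned by `get_fa_num_small_rows(path)` are not shown in the diff.

```cpp
#include <cstdint>

// Hypothetical stand-in for the backend's code-path enum; only the three
// names that appear in the diff are modeled here.
enum FaCodePath { FA_SCALAR, FA_COOPMAT1, FA_COOPMAT2 };

// Assumption: the real threshold comes from get_fa_num_small_rows(path),
// whose values are not visible in the diff, so the caller supplies it.
FaCodePath select_fa_path(FaCodePath path, uint32_t N, uint32_t small_rows_threshold) {
    const bool small_rows = N <= small_rows_threshold;

    // coopmat1 needs 16 rows, so small-row cases fall back to scalar.
    if (small_rows && path == FA_COOPMAT1) {
        path = FA_SCALAR;
    }

    // Added by this commit: with a single row, scalar is faster than coopmat2.
    if (N == 1 && path == FA_COOPMAT2) {
        path = FA_SCALAR;
    }

    return path;
}

// Mirrors the f32acc line in the diff: the scalar path always accumulates in
// f32; other paths do so only when F32 precision was requested.
bool use_f32_accumulator(FaCodePath path, bool prec_f32_requested) {
    return path == FA_SCALAR || prec_f32_requested;
}
```

One consequence visible in this sketch is that the new `N == 1` fallback also forces f32 accumulation, since the accumulator choice is made after the path has been rewritten to `FA_SCALAR`.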
