Skip to content

Commit 3a8d954

Browse files
committed
vulkan: always use fp32 for scalar flash attention
1 parent 005756a commit 3a8d954

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5726,9 +5726,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
57265726
assert(k->type == v->type);
57275727

57285728
vk_pipeline *pipelines;
5729-
// XXX TODO other backends may be changing accumulator precision to default to f32 soon
5730-
bool f32acc = dst->op_params[3] == GGML_PREC_F32;
57315729
bool scalar = !ctx->device->coopmat2;
5730+
// XXX TODO other backends may be changing accumulator precision to default to f32 soon
5731+
bool f32acc = scalar || dst->op_params[3] == GGML_PREC_F32;
57325732
bool small_rows = N <= get_fa_num_small_rows(scalar);
57335733

57345734
if (scalar) {

0 commit comments

Comments
 (0)