
Commit c11d05f

llama : force disable flash attention for incompatible models
1 parent cb76d74 commit c11d05f

File tree

1 file changed: +13 -1 lines changed

llama.cpp

Lines changed: 13 additions & 1 deletion
@@ -1823,7 +1823,7 @@ struct llama_hparams {
     float f_logit_scale = 0.0f;
 
     bool causal_attn = true;
-    bool need_kq_pos = false;
+    bool need_kq_pos = false; // currently, we need KQ_pos data for ALiBi-based models
 
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
     enum llama_rope_type    rope_type    = LLAMA_ROPE_TYPE_NONE;
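
For context on why `need_kq_pos` exists: ALiBi-based models (e.g. MPT, Refact) skip rotary embeddings and instead add a per-head linear bias to the attention scores based on the query/key distance, which is what the KQ_pos data feeds. A minimal standalone sketch of that bias, following the ALiBi paper's slope schedule rather than the actual ggml kernels (the function names here are illustrative):

```c
#include <math.h>

// Illustrative sketch of the ALiBi bias (not the ggml implementation).
// Each head h gets a slope m_h; the score between query position i and
// key position j (j <= i) receives an additive bias of -m_h * (i - j).
static float alibi_slope(int head, int n_head, float max_bias /* typically 8.0f */) {
    const int   n_head_log2 = 1 << (int) floorf(log2f((float) n_head));
    const float m0 = powf(2.0f, -max_bias        / n_head_log2);
    const float m1 = powf(2.0f, -max_bias / 2.0f / n_head_log2);

    // heads beyond the nearest power of two use the interleaved sequence
    return head < n_head_log2
        ? powf(m0, head + 1)
        : powf(m1, 2*(head - n_head_log2) + 1);
}

static float alibi_bias(int head, int n_head, int i, int j) {
    return -alibi_slope(head, n_head, 8.0f) * (float)(i - j);
}
```

Because this bias needs the key positions (KQ_pos), and the current flash attention path never receives them, such models cannot use flash attention yet, hence the checks below.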
@@ -6311,6 +6311,8 @@ static struct ggml_tensor * llm_build_kqv(
     GGML_UNUSED(model);
     GGML_UNUSED(n_ctx);
 
+    // note: if this assert triggers, then some check has failed earlier
+    // the idea is to detect during context creation that ALiBi would be used and disable Flash Attention
     GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention");
 
     // split cached v into n_head heads (not transposed)
@@ -15114,6 +15116,16 @@ struct llama_context * llama_new_context_with_model(
         }
     }
 
+    if (cparams.flash_attn && hparams.need_kq_pos) {
+        LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with ALiBi - forcing off\n", __func__);
+        cparams.flash_attn = false;
+    }
+
+    if (cparams.flash_attn && model->arch == LLM_ARCH_GROK) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
+        cparams.flash_attn = false;
+    }
+
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
     }
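
From the API side, the effect of these checks is that a caller who requests flash attention for an ALiBi-based model (or Grok) gets a warning during context creation and runs with it disabled, instead of tripping the assert in llm_build_kqv later. A rough sketch of that interaction, assuming the C API of this version of llama.cpp; the model path is a placeholder:

```c
#include "llama.h"

int main(void) {
    llama_backend_init();

    struct llama_model_params mparams = llama_model_default_params();
    // placeholder path to an ALiBi-based model (e.g. an MPT GGUF)
    struct llama_model * model = llama_load_model_from_file("mpt-7b.Q4_0.gguf", mparams);
    if (model == NULL) {
        return 1;
    }

    struct llama_context_params cparams = llama_context_default_params();
    cparams.flash_attn = true; // requested, but incompatible with ALiBi

    // logs: "llama_new_context_with_model: flash_attn is not yet compatible with ALiBi - forcing off"
    // and continues with flash attention disabled rather than asserting at graph build time
    struct llama_context * ctx = llama_new_context_with_model(model, cparams);

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```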
