@@ -326,6 +326,8 @@ enum llm_kv {
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,
+    LLM_KV_ATTN_LOGIT_SOFTCAPPING,
+    LLM_KV_FINAL_LOGIT_SOFTCAPPING,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -416,6 +418,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_POOLING_TYPE,            "%s.pooling_type"            },
     { LLM_KV_LOGIT_SCALE,             "%s.logit_scale"             },
     { LLM_KV_DECODER_START_TOKEN_ID,  "%s.decoder_start_token_id"  },
+    { LLM_KV_ATTN_LOGIT_SOFTCAPPING,  "%s.attn_logit_softcapping"  },
+    { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },
 
     { LLM_KV_ATTENTION_HEAD_COUNT,    "%s.attention.head_count"    },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
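Side note, not part of the patch: the LLM_KV_NAMES entries are format strings whose "%s" is filled with the model's architecture name, so for Gemma-2 models the two new entries should resolve to GGUF metadata keys along the lines of gemma2.attn_logit_softcapping and gemma2.final_logit_softcapping. Treat the "gemma2" architecture string as an assumption here; a minimal sketch of the expansion:

    // Minimal sketch (not part of the patch): expanding the "%s" placeholder,
    // assuming the architecture string for these models is "gemma2".
    #include <cstdio>

    int main() {
        char key[128];
        std::snprintf(key, sizeof(key), "%s.attn_logit_softcapping", "gemma2");
        std::printf("%s\n", key); // gemma2.attn_logit_softcapping
        return 0;
    }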
@@ -2127,6 +2131,9 @@ struct llama_hparams {
     float f_norm_eps;
     float f_norm_rms_eps;
 
+    float f_attn_logit_softcapping  = 50.0f;
+    float f_final_logit_softcapping = 30.0f;
+
     float rope_attn_factor = 1.0f;
     float rope_freq_base_train;
     float rope_freq_scale_train;
@@ -2143,8 +2150,9 @@ struct llama_hparams {
     float f_max_alibi_bias = 0.0f;
     float f_logit_scale = 0.0f;
 
-    bool causal_attn = true;
-    bool use_alibi   = false;
+    bool causal_attn   = true;
+    bool use_alibi     = false;
+    bool attn_soft_cap = false;
 
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
     enum llama_rope_type    rope_type    = LLAMA_ROPE_TYPE_NONE;
@@ -4755,6 +4763,9 @@ static void llm_load_hparams(
         case LLM_ARCH_GEMMA2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING,      hparams.f_attn_logit_softcapping, false);
+                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING,     hparams.f_final_logit_softcapping, false);
+                hparams.attn_soft_cap = true;
 
                 switch (hparams.n_layer) {
                     case 42: model.type = e_model::MODEL_9B; break;
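A note on the loading step: assuming the third argument of ml.get_key is the usual "required" flag, passing false makes both keys optional, so GGUF files converted before this change simply keep the compiled-in defaults (50.0f and 30.0f) from llama_hparams. A rough sketch of that fallback behaviour, where read_float_kv is a hypothetical stand-in rather than llama.cpp API:

    // Illustrative sketch only: the fallback behaviour assumed for
    // ml.get_key(kv, dst, /*required=*/false). read_float_kv is hypothetical.
    #include <map>
    #include <stdexcept>
    #include <string>

    static bool read_float_kv(const std::map<std::string, float> & meta,
                              const std::string & key, float & dst, bool required) {
        const auto it = meta.find(key);
        if (it == meta.end()) {
            if (required) {
                throw std::runtime_error("missing key: " + key);
            }
            return false; // optional key absent: dst keeps its compiled-in default
        }
        dst = it->second;
        return true;
    }

    int main() {
        std::map<std::string, float> meta; // pretend this GGUF lacks the softcapping keys
        float f_attn_logit_softcapping = 50.0f;
        read_float_kv(meta, "gemma2.attn_logit_softcapping", f_attn_logit_softcapping, false);
        return 0; // f_attn_logit_softcapping is still 50.0f
    }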
@@ -7658,6 +7669,12 @@ static struct ggml_tensor * llm_build_kqv(
        kq = ggml_scale(ctx, kq, 30);
    }
 
+    if (hparams.attn_soft_cap) {
+        kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
+        kq = ggml_tanh(ctx, kq);
+        kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
+    }
+
    kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
    cb(kq, "kq_soft_max_ext", il);
@@ -11118,7 +11135,7 @@ struct llm_build_context {
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
-                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
+                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head)));
                 cb(Qcur, "Qcur_scaled", il);
 
                 Kcur = ggml_rope_ext(
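The query-scaling change above switches the denominator from the key head dimension to n_embd / n_head. As a worked example, assuming the published Gemma-2 9B shapes (n_embd = 3584, n_head = 16, n_embd_head_k = 256), the scale moves from 1/sqrt(256) = 0.0625 to 1/sqrt(224) ≈ 0.0668; treat those shape values as an assumption, since they are not stated in this diff:

    // Worked example of the changed query scale, assuming Gemma-2 9B shapes
    // (n_embd = 3584, n_head = 16, n_embd_head_k = 256) -- an assumption.
    #include <cmath>
    #include <cstdio>

    int main() {
        const int n_embd = 3584, n_head = 16, n_embd_head_k = 256;
        std::printf("old: %f\n", 1.0f / std::sqrt((float) n_embd_head_k));     // 1/sqrt(256) = 0.0625
        std::printf("new: %f\n", 1.0f / std::sqrt((float) (n_embd / n_head))); // 1/sqrt(224) ≈ 0.0668
        return 0;
    }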
@@ -11185,6 +11202,12 @@ struct llm_build_context {
 
         // lm_head
         cur = ggml_mul_mat(ctx0, model.output, cur);
+
+        // final logit soft-capping
+        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+        cur = ggml_tanh(ctx0, cur);
+        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
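The same tanh-based cap is applied once more to the lm_head output, so with the default f_final_logit_softcapping = 30 every output logit lands in (-30, 30) before sampling. A quick self-contained check of that bound:

    // Quick check of the final-logit cap with the default value of 30:
    // a raw logit of 100 is squashed to 30 * tanh(100/30) ≈ 29.92.
    #include <cmath>
    #include <cstdio>

    int main() {
        const float cap = 30.0f;
        std::printf("%f\n", cap * std::tanh(100.0f / cap)); // ≈ 29.92
        return 0;
    }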
@@ -17709,6 +17732,12 @@ struct llama_context * llama_new_context_with_model(
        params.flash_attn = false;
    }
 
+    if (params.flash_attn && model->hparams.attn_soft_cap) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
+
    if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
        LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
        params.flash_attn = false;