Skip to content

Commit b1f9c97

Browse files
committed
Merge branch 'upstream' into concedo_experimental
2 parents 8a07ce3 + 1c5eba6 commit b1f9c97

File tree

4 files changed

+46
-3
lines changed

4 files changed

+46
-3
lines changed

convert-hf-to-gguf.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2363,6 +2363,12 @@ def set_gguf_parameters(self):
23632363
self.gguf_writer.add_key_length(hparams["head_dim"])
23642364
self.gguf_writer.add_value_length(hparams["head_dim"])
23652365
self.gguf_writer.add_file_type(self.ftype)
2366+
self.gguf_writer.add_attn_logit_softcapping(
2367+
self.hparams["attn_logit_softcapping"]
2368+
)
2369+
self.gguf_writer.add_final_logit_softcapping(
2370+
self.hparams["final_logit_softcapping"]
2371+
)
23662372

23672373
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
23682374
del bid # unusem

gguf-py/gguf/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ class LLM:
5050
POOLING_TYPE = "{arch}.pooling_type"
5151
LOGIT_SCALE = "{arch}.logit_scale"
5252
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
53+
ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping"
54+
FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"
5355

5456
class Attention:
5557
HEAD_COUNT = "{arch}.attention.head_count"

gguf-py/gguf/gguf_writer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,12 @@ def add_clamp_kqv(self, value: float) -> None:
516516
def add_logit_scale(self, value: float) -> None:
517517
self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)
518518

519+
def add_attn_logit_softcapping(self, value: float) -> None:
520+
self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
521+
522+
def add_final_logit_softcapping(self, value: float) -> None:
523+
self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
524+
519525
def add_expert_count(self, count: int) -> None:
520526
self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)
521527

src/llama.cpp

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,8 @@ enum llm_kv {
326326
LLM_KV_POOLING_TYPE,
327327
LLM_KV_LOGIT_SCALE,
328328
LLM_KV_DECODER_START_TOKEN_ID,
329+
LLM_KV_ATTN_LOGIT_SOFTCAPPING,
330+
LLM_KV_FINAL_LOGIT_SOFTCAPPING,
329331

330332
LLM_KV_ATTENTION_HEAD_COUNT,
331333
LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -416,6 +418,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
416418
{ LLM_KV_POOLING_TYPE , "%s.pooling_type" },
417419
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
418420
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
421+
{ LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" },
422+
{ LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },
419423

420424
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
421425
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -2127,6 +2131,9 @@ struct llama_hparams {
21272131
float f_norm_eps;
21282132
float f_norm_rms_eps;
21292133

2134+
float f_attn_logit_softcapping = 50.0f;
2135+
float f_final_logit_softcapping = 30.0f;
2136+
21302137
float rope_attn_factor = 1.0f;
21312138
float rope_freq_base_train;
21322139
float rope_freq_scale_train;
@@ -2143,8 +2150,9 @@ struct llama_hparams {
21432150
float f_max_alibi_bias = 0.0f;
21442151
float f_logit_scale = 0.0f;
21452152

2146-
bool causal_attn = true;
2147-
bool use_alibi = false;
2153+
bool causal_attn = true;
2154+
bool use_alibi = false;
2155+
bool attn_soft_cap = false;
21482156

21492157
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
21502158
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
@@ -4755,6 +4763,9 @@ static void llm_load_hparams(
47554763
case LLM_ARCH_GEMMA2:
47564764
{
47574765
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
4766+
ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
4767+
ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
4768+
hparams.attn_soft_cap = true;
47584769

47594770
switch (hparams.n_layer) {
47604771
case 42: model.type = e_model::MODEL_9B; break;
@@ -7658,6 +7669,12 @@ static struct ggml_tensor * llm_build_kqv(
76587669
kq = ggml_scale(ctx, kq, 30);
76597670
}
76607671

7672+
if (hparams.attn_soft_cap) {
7673+
kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
7674+
kq = ggml_tanh(ctx, kq);
7675+
kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
7676+
}
7677+
76617678
kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
76627679
cb(kq, "kq_soft_max_ext", il);
76637680

@@ -11118,7 +11135,7 @@ struct llm_build_context {
1111811135
ext_factor, attn_factor, beta_fast, beta_slow);
1111911136
cb(Qcur, "Qcur", il);
1112011137

11121-
Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
11138+
Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head)));
1112211139
cb(Qcur, "Qcur_scaled", il);
1112311140

1112411141
Kcur = ggml_rope_ext(
@@ -11185,6 +11202,12 @@ struct llm_build_context {
1118511202

1118611203
// lm_head
1118711204
cur = ggml_mul_mat(ctx0, model.output, cur);
11205+
11206+
// final logit soft-capping
11207+
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
11208+
cur = ggml_tanh(ctx0, cur);
11209+
cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
11210+
1118811211
cb(cur, "result_output", -1);
1118911212

1119011213
ggml_build_forward_expand(gf, cur);
@@ -17709,6 +17732,12 @@ struct llama_context * llama_new_context_with_model(
1770917732
params.flash_attn = false;
1771017733
}
1771117734

17735+
if (params.flash_attn && model->hparams.attn_soft_cap) {
17736+
LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
17737+
params.flash_attn = false;
17738+
}
17739+
17740+
1771217741
if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
1771317742
LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
1771417743
params.flash_attn = false;

0 commit comments

Comments
 (0)