
Commit 7269067

Add llama 3.1 rope scaling factors to llama conversion and inference
This commit generates the rope factors at conversion time and stores them in the resulting model as a tensor. At inference time, these factors are passed to the `ggml_rope_ext` rope operation, improving results for context windows above 8192 tokens.
1 parent 68504f0 commit 7269067
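
For background on what the conversion change computes: the llama3 rope scaling scheme decides, per RoPE dimension, how much to stretch it, based on that dimension's wavelength relative to the original 8192-token context. Short-wavelength (high-frequency) dimensions keep a factor of 1, long-wavelength (low-frequency) dimensions get the full scaling factor, and the band in between is smoothly interpolated. Below is a minimal standalone sketch of that computation; the defaults mirror the conversion code in this commit, while the function name, head_dim = 128, and base = 500000.0 are assumptions typical of Llama 3.1 configs rather than values taken from this diff.

    # Minimal sketch of the llama3 rope-factor computation added by this commit.
    # head_dim and base are assumed Llama 3.1 values, not read from a config here.
    import math

    def llama3_rope_factors(head_dim=128, base=500000.0, factor=8.0,
                            low_freq_factor=1.0, high_freq_factor=4.0,
                            old_context_len=8192):
        low_freq_wavelen  = old_context_len / low_freq_factor   # longer wavelengths get the full factor
        high_freq_wavelen = old_context_len / high_freq_factor  # shorter wavelengths stay unscaled

        factors = []
        for i in range(0, head_dim, 2):
            freq = 1.0 / (base ** (i / head_dim))   # per-dimension rotation frequency
            wavelen = 2 * math.pi / freq            # positions per full rotation
            if wavelen < high_freq_wavelen:
                factors.append(1.0)
            elif wavelen > low_freq_wavelen:
                factors.append(factor)
            else:
                smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
                factors.append(1.0 / ((1.0 - smooth) / factor + smooth))
        return factors  # one entry per rotated dimension pair (head_dim / 2 values)

    factors = llama3_rope_factors()
    print(len(factors), factors[0], factors[-1])  # 64 values, 1.0 for the fastest dims, 8.0 for the slowest

The factors ramp continuously from 1.0 to the full scaling factor, which is what the conversion loop below writes out as a tensor.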

2 files changed (+41, -2 lines)


convert_hf_to_gguf.py

Lines changed: 29 additions & 0 deletions
@@ -1514,6 +1514,35 @@ def set_gguf_parameters(self):
         if self.hparams.get("vocab_size", 32000) == 49152:
             self.gguf_writer.add_add_bos_token(False)

+        if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
+            if rope_scaling.get("rope_type", '').lower() == "llama3":
+                base = hparams.get("rope_theta", 10000.0)
+                dim = int((hparams["hidden_size"] // hparams["num_attention_heads"]) * hparams.get("partial_rotary_embeddings", 1.0))
+                freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+
+                factor = rope_scaling.get("factor", 8.0)
+                low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
+                high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
+                old_context_len = hparams.get("original_max_position_embeddings", 8192)
+
+                low_freq_wavelen = old_context_len / low_freq_factor
+                high_freq_wavelen = old_context_len / high_freq_factor
+
+                rope_factors = []
+                for freq in freqs:
+                    wavelen = 2 * math.pi / freq
+                    if wavelen < high_freq_wavelen:
+                        rope_factors.append(1)
+                    elif wavelen > low_freq_wavelen:
+                        rope_factors.append(factor)
+                    else:
+                        assert low_freq_wavelen != high_freq_wavelen
+                        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+                        rope_factors.append(1 / ((1 - smooth) / factor + smooth))
+
+                self.gguf_writer.add_rope_scaling_attn_factors(1.0)
+                self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FREQS] + ".weight", np.array(rope_factors, dtype=np.float32))
+
     @staticmethod
     def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
         if n_head_kv is not None and n_head != n_head_kv:
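
To confirm a conversion picked this up, the factors should end up in the GGUF as a plain tensor named rope_freqs.weight (from gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FREQS] + ".weight" above). A rough sketch of checking that with gguf-py's GGUFReader; the reader usage and the model filename are assumptions for illustration, not part of this commit:

    # Sketch: inspect the converted model for the new rope_freqs.weight tensor.
    from gguf import GGUFReader  # gguf-py, shipped in this repo

    reader = GGUFReader("llama-3.1-8b-instruct-f16.gguf")  # hypothetical converted model
    for tensor in reader.tensors:
        if tensor.name == "rope_freqs.weight":
            # Expect head_dim / 2 float32 factors, ranging from 1.0 up to the scaling factor (8.0).
            print(tensor.name, tensor.data.shape, tensor.data.min(), tensor.data.max())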

src/llama.cpp

Lines changed: 12 additions & 2 deletions
@@ -2455,6 +2455,7 @@ struct llama_layer {
     // long rope factors
     struct ggml_tensor * rope_long = nullptr;
     struct ggml_tensor * rope_short = nullptr;
+    struct ggml_tensor * rope_freqs = nullptr;

     // bitnet scale
     struct ggml_tensor * wq_scale;
@@ -6055,6 +6056,8 @@ static bool llm_load_tensors(

                    layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});

+                    layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), { n_embd/n_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+
                    if (n_expert == 0) {
                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
@@ -8532,6 +8535,10 @@ struct llm_build_context {
         // choose long/short freq factors based on the context size
         const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max;

+        if (model.layers[il].rope_freqs != nullptr) {
+            return model.layers[il].rope_freqs;
+        }
+
         if (n_ctx_pre_seq > hparams.n_ctx_orig_yarn) {
             return model.layers[il].rope_long;
         }
@@ -8726,6 +8733,9 @@ struct llm_build_context {

            // self-attention
            {
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                struct ggml_tensor * rope_factors = build_rope_factors(il);
+
                // compute Q and K and RoPE them
                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                cb(Qcur, "Qcur", il);
@@ -8749,14 +8759,14 @@ struct llm_build_context {
                }

                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                    ext_factor, attn_factor, beta_fast, beta_slow
                );
                cb(Qcur, "Qcur", il);

                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                    ext_factor, attn_factor, beta_fast, beta_slow
                );
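
For intuition about the inference side: as I understand ggml_rope_ext, a non-null freq-factors tensor divides each dimension's rotation frequency by its factor, so a dimension with factor 8.0 at position 32768 rotates only as far as it would have at position 4096 without scaling, which is how contexts beyond the original 8192 tokens stay in range. A small illustrative sketch in plain Python, not ggml code; the helper below and its parameters are hypothetical:

    # Illustration only: how a per-dimension freq factor changes the RoPE angle.
    # ggml applies the factor internally; this mirrors the idea, not the exact code.
    def rope_angle(pos, dim_pair, head_dim=128, base=500000.0, freq_factor=1.0):
        freq = 1.0 / (base ** (2 * dim_pair / head_dim))
        return pos * freq / freq_factor  # factor > 1 slows the rotation of that dimension

    pos = 32768      # position well beyond the original 8192-token window
    slow_pair = 63   # slowest-rotating dimension pair; factor 8.0 in the llama3 scheme
    print(rope_angle(pos, slow_pair), rope_angle(pos, slow_pair, freq_factor=8.0))
    # with factor 8.0 the angle matches what position 4096 would see without scaling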

0 commit comments
