Commit e661170

jmorganca authored and compilade committed
llama : add support for llama 3.1 rope scaling factors (ggml-org#8676)
* Add llama 3.1 rope scaling factors to llama conversion and inference

  This commit generates the rope factors on conversion and adds them to the resulting model as a tensor. At inference time, these factors are passed to the `ggml_rope_ext` rope operation, improving results for context windows above 8192.

* Update convert_hf_to_gguf.py

  Co-authored-by: compilade <[email protected]>

* address comments

* address comments

* Update src/llama.cpp

  Co-authored-by: compilade <[email protected]>

* Update convert_hf_to_gguf.py

  Co-authored-by: compilade <[email protected]>

---------

Co-authored-by: compilade <[email protected]>
1 parent 67501cf commit e661170
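
For context: the "llama3" rope scaling scheme added here rescales each RoPE frequency according to its wavelength, leaving short-wavelength (high-frequency) dimensions untouched and stretching long-wavelength ones. Below is a minimal sketch of the two wavelength thresholds implied by the fallback defaults hard-coded in the conversion change further down (factor 8, low_freq_factor 1, high_freq_factor 4, original context length 8192); it is an illustration, not part of the commit.

# Illustration only: wavelength thresholds implied by the default
# hyperparameters used as fallbacks in convert_hf_to_gguf.py below.
old_context_len  = 8192    # original_max_position_embeddings
low_freq_factor  = 1.0
high_freq_factor = 4.0
factor           = 8.0

low_freq_wavelen  = old_context_len / low_freq_factor   # 8192.0
high_freq_wavelen = old_context_len / high_freq_factor  # 2048.0

# Dimensions whose RoPE wavelength is shorter than 2048 positions keep a
# factor of 1, dimensions longer than 8192 positions get the full factor of 8,
# and the band in between is smoothly interpolated.
print(high_freq_wavelen, low_freq_wavelen)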

File tree

2 files changed, +40 -2 lines changed


convert_hf_to_gguf.py

Lines changed: 28 additions & 0 deletions
@@ -1570,6 +1570,34 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(self.map_tensor_name(name), data_torch)]
 
     def prepare_tensors(self):
+        if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
+            if rope_scaling.get("rope_type", '').lower() == "llama3":
+                base = self.hparams.get("rope_theta", 10000.0)
+                dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+                freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+
+                factor = rope_scaling.get("factor", 8.0)
+                low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
+                high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
+                old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
+
+                low_freq_wavelen = old_context_len / low_freq_factor
+                high_freq_wavelen = old_context_len / high_freq_factor
+                assert low_freq_wavelen != high_freq_wavelen
+
+                rope_factors = []
+                for freq in freqs:
+                    wavelen = 2 * math.pi / freq
+                    if wavelen < high_freq_wavelen:
+                        rope_factors.append(1)
+                    elif wavelen > low_freq_wavelen:
+                        rope_factors.append(factor)
+                    else:
+                        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+                        rope_factors.append(1 / ((1 - smooth) / factor + smooth))
+
+                self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))
+
         super().prepare_tensors()
 
         if self._experts is not None:
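
For readers who want to play with the numbers, here is a self-contained sketch of the same computation outside the converter class. The hparams are illustrative values in the spirit of a Llama 3.1 8B config (hidden_size 4096, 32 attention heads, rope_theta 500000); they are assumptions for the example, not taken from this diff.

# A self-contained sketch of the conversion step above, lifted out of the
# converter class so it can be run on its own. The hparams below are example
# values chosen for illustration, not read from any model file.
import math

import numpy as np
import torch

hparams = {"hidden_size": 4096, "num_attention_heads": 32, "rope_theta": 500000.0}
rope_scaling = {
    "rope_type": "llama3",
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}

base = hparams.get("rope_theta", 10000.0)
dim = hparams["hidden_size"] // hparams["num_attention_heads"]  # 128 per head
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

factor = rope_scaling["factor"]
low_freq_factor = rope_scaling["low_freq_factor"]
high_freq_factor = rope_scaling["high_freq_factor"]
old_context_len = rope_scaling["original_max_position_embeddings"]

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor

rope_factors = []
for freq in freqs:
    wavelen = 2 * math.pi / freq
    if wavelen < high_freq_wavelen:
        rope_factors.append(1)        # high-frequency dims: left unchanged
    elif wavelen > low_freq_wavelen:
        rope_factors.append(factor)   # low-frequency dims: full stretch
    else:                             # transition band: smooth interpolation
        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        rope_factors.append(1 / ((1 - smooth) / factor + smooth))

arr = np.array(rope_factors, dtype=np.float32)
print(arr.shape, arr.min(), arr.max())  # (64,) 1.0 8.0, one factor per rotated pair

These 64 factors are what the converter writes into the rope_freqs tensor; at inference time, passing them to `ggml_rope_ext` effectively divides each RoPE frequency by its factor, stretching the low-frequency wavelengths by up to 8x.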

src/llama.cpp

Lines changed: 12 additions & 2 deletions
@@ -2451,6 +2451,7 @@ struct llama_layer {
     // long rope factors
     struct ggml_tensor * rope_long  = nullptr;
     struct ggml_tensor * rope_short = nullptr;
+    struct ggml_tensor * rope_freqs = nullptr;
 
     // bitnet scale
     struct ggml_tensor * wq_scale;
@@ -6060,6 +6061,8 @@ static bool llm_load_tensors(
 
                     layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
 
+                    layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+
                     if (n_expert == 0) {
                         layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
                         layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
@@ -8537,6 +8540,10 @@ struct llm_build_context {
         // choose long/short freq factors based on the context size
         const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max;
 
+        if (model.layers[il].rope_freqs != nullptr) {
+            return model.layers[il].rope_freqs;
+        }
+
         if (n_ctx_pre_seq > hparams.n_ctx_orig_yarn) {
             return model.layers[il].rope_long;
         }
@@ -8731,6 +8738,9 @@ struct llm_build_context {
 
             // self-attention
             {
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                struct ggml_tensor * rope_factors = build_rope_factors(il);
+
                 // compute Q and K and RoPE them
                 struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
@@ -8754,14 +8764,14 @@ struct llm_build_context {
                 }
 
                 Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
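
On the inference side, build_rope_factors now returns the per-layer rope_freqs tensor whenever it is present, and the loader creates it with TENSOR_NOT_REQUIRED, so existing GGUF files without the tensor simply keep the previous nullptr behaviour. A quick, assumption-laden way to check that a converted file actually carries the factors, sketched with the GGUFReader class from the gguf Python package that ships with llama.cpp (the file name is a placeholder, and the tensor is expected to be named rope_freqs.weight, matching LLM_TENSOR_ROPE_FREQS):

# Sketch only: verify that a converted Llama 3.1 GGUF contains the new tensor.
# Assumes the gguf-py package from the llama.cpp repo; the path is hypothetical.
from gguf import GGUFReader

reader = GGUFReader("llama-3.1-8b-instruct.gguf")  # placeholder file name

for tensor in reader.tensors:
    if "rope_freqs" in tensor.name:
        # Expected: n_embd / n_head / 2 entries, e.g. 4096 / 32 / 2 = 64 for an
        # 8B-style config, matching the array written by convert_hf_to_gguf.py.
        print(tensor.name, tensor.shape)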
