
Commit 0211330

llama: allow to override model rope type
Signed-off-by: Giuseppe Scrivano <[email protected]>
1 parent 120f7bf · commit 0211330

4 files changed: +42 −3 lines

gguf-py/gguf/constants.py (9 additions, 0 deletions)

@@ -57,6 +57,7 @@ class Attention:
         CAUSAL = "{arch}.attention.causal"

     class Rope:
+        TYPE            = "{arch}.rope.type"
         DIMENSION_COUNT = "{arch}.rope.dimension_count"
         FREQ_BASE       = "{arch}.rope.freq_base"
         SCALING_TYPE    = "{arch}.rope.scaling.type"

@@ -806,6 +807,13 @@ class TokenType(IntEnum):
     BYTE = 6


+class RopeType(Enum):
+    NONE = 'none'
+    NORM = 'norm'
+    NEOX = 'neox'
+    GLM  = 'glm'
+
+
 class RopeScalingType(Enum):
     NONE = 'none'
     LINEAR = 'linear'

@@ -998,6 +1006,7 @@ def get_type(val: Any) -> GGUFValueType:
 KEY_ATTENTION_LAYERNORM_RMS_EPS = Keys.Attention.LAYERNORM_RMS_EPS

 # RoPE
+KEY_ROPE_TYPE            = Keys.Rope.TYPE
 KEY_ROPE_DIMENSION_COUNT = Keys.Rope.DIMENSION_COUNT
 KEY_ROPE_FREQ_BASE       = Keys.Rope.FREQ_BASE
 KEY_ROPE_SCALING_TYPE    = Keys.Rope.SCALING_TYPE
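For orientation, here is how the new templated key and enum resolve. This is a minimal sketch, not part of the commit, assuming the gguf-py package from this tree and "llama" as an example architecture name:

```python
# Minimal sketch (not part of the commit): resolving the new key template
# and RopeType enum. Assumes gguf-py from this tree; "llama" is an example
# architecture name.
from gguf.constants import Keys, RopeType

print(Keys.Rope.TYPE.format(arch="llama"))  # -> "llama.rope.type"
print([t.value for t in RopeType])          # -> ['none', 'norm', 'neox', 'glm']
```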

gguf-py/gguf/gguf_writer.py (3 additions, 0 deletions)

@@ -427,6 +427,9 @@ def add_rope_dimension_count(self, count: int) -> None:
     def add_rope_freq_base(self, value: float) -> None:
         self.add_float32(Keys.Rope.FREQ_BASE.format(arch=self.arch), value)

+    def add_rope_type(self, value: RopeType) -> None:
+        self.add_string(Keys.Rope.TYPE.format(arch=self.arch), value.value)
+
     def add_rope_scaling_type(self, value: RopeScalingType) -> None:
         self.add_string(Keys.Rope.SCALING_TYPE.format(arch=self.arch), value.value)
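A conversion script could emit the override like so. This is a hedged usage sketch assuming gguf-py from this tree; "model.gguf" is a hypothetical output path and the write/close calls are the usual GGUFWriter sequence:

```python
# Hedged usage sketch (not part of the commit). Assumes gguf-py from this
# tree; "model.gguf" is a hypothetical output path.
from gguf import GGUFWriter
from gguf.constants import RopeType

writer = GGUFWriter("model.gguf", arch="llama")
writer.add_rope_type(RopeType.NEOX)  # stored as "llama.rope.type" = "neox"
# ... add the remaining metadata and tensors as usual ...
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()
```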

llama.cpp (29 additions, 2 deletions)

@@ -297,6 +297,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,
     LLM_KV_ATTENTION_CAUSAL,

+    LLM_KV_ROPE_TYPE,
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_FREQ_BASE,
     LLM_KV_ROPE_SCALE_LINEAR,

@@ -375,6 +376,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
     { LLM_KV_ATTENTION_CAUSAL,            "%s.attention.causal"                 },

+    { LLM_KV_ROPE_TYPE,            "%s.rope.type"            },
     { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
     { LLM_KV_ROPE_FREQ_BASE,       "%s.rope.freq_base"       },
     { LLM_KV_ROPE_SCALE_LINEAR,    "%s.rope.scale_linear"    },

@@ -1129,12 +1131,29 @@ struct LLM_TN {
 // gguf helpers
 //

+static const std::map<enum llama_rope_type, const char *> LLAMA_ROPE_TYPES = {
+    { LLAMA_ROPE_TYPE_NONE, "none" },
+    { LLAMA_ROPE_TYPE_NORM, "norm" },
+    { LLAMA_ROPE_TYPE_NEOX, "neox" },
+    { LLAMA_ROPE_TYPE_GLM,  "glm"  },
+};
+
 static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_TYPES = {
     { LLAMA_ROPE_SCALING_TYPE_NONE,   "none"   },
     { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" },
     { LLAMA_ROPE_SCALING_TYPE_YARN,   "yarn"   },
 };

+static enum llama_rope_type llama_rope_type_from_string(const std::string & name) {
+    for (const auto & kv : LLAMA_ROPE_TYPES) {
+        if (kv.second == name) {
+            return (enum llama_rope_type) kv.first;
+        }
+    }
+
+    return LLAMA_ROPE_TYPE_NONE;
+}
+
 static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
     for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
         if (kv.second == name) {

@@ -4394,7 +4413,15 @@ static void llm_load_hparams(
             hparams.use_alibi = true;
         }

-        hparams.rope_type = llama_rope_type(&model);
+        hparams.rope_type = llama_default_rope_type(&model);
+
+        const auto kv = LLM_KV(model.arch);
+        const int rope_type_keyidx = gguf_find_key(ctx, kv(LLM_KV_ROPE_TYPE).c_str());
+        if (rope_type_keyidx != -1) {
+            std::string rope_type("none");
+            ml.get_key(LLM_KV_ROPE_TYPE, rope_type);
+            hparams.rope_type = llama_rope_type_from_string(rope_type);
+        }
     }

     // TODO: This should probably be in llama.h

@@ -16216,7 +16243,7 @@ enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
     return model->vocab.type;
 }

-enum llama_rope_type llama_rope_type(const struct llama_model * model) {
+enum llama_rope_type llama_default_rope_type(const struct llama_model * model) {
     switch (model->arch) {
         // these models do not use RoPE
         case LLM_ARCH_GPT2:
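Net effect of the loader change: the architecture-derived default still applies, but a "{arch}.rope.type" string in the GGUF metadata now takes precedence, and an unrecognized string falls back to LLAMA_ROPE_TYPE_NONE. A reader-side pseudocode mirror of that precedence, for illustration only:

```python
# Pseudocode mirror of the C++ override logic above (illustration only,
# not part of the commit).
ROPE_TYPES = {"none", "norm", "neox", "glm"}

def resolve_rope_type(arch_default: str, metadata: dict[str, str]) -> str:
    value = metadata.get("llama.rope.type")  # kv(LLM_KV_ROPE_TYPE) for arch "llama"
    if value is None:
        return arch_default                  # llama_default_rope_type(&model)
    return value if value in ROPE_TYPES else "none"
```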

llama.h (1 addition, 1 deletion)

@@ -422,7 +422,7 @@ extern "C" {
    LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);

    LLAMA_API enum llama_vocab_type llama_vocab_type        (const struct llama_model * model);
-   LLAMA_API enum llama_rope_type  llama_rope_type         (const struct llama_model * model);
+   LLAMA_API enum llama_rope_type  llama_default_rope_type (const struct llama_model * model);

    LLAMA_API int32_t llama_n_vocab    (const struct llama_model * model);
    LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
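Because llama_rope_type is renamed rather than kept as an alias, downstream callers must switch to llama_default_rope_type. On the file side, the override can be verified after conversion; a hypothetical sanity check, assuming gguf-py's GGUFReader and the "model.gguf" written above:

```python
# Hypothetical sanity check (not part of the commit): read the key back.
# Assumes gguf-py's GGUFReader and the "model.gguf" written above.
from gguf import GGUFReader

reader = GGUFReader("model.gguf")
field = reader.get_field("llama.rope.type")
if field is not None:
    # for a scalar string field, the last part holds the UTF-8 value bytes
    print(bytes(field.parts[-1]).decode("utf-8"))  # -> "neox"
```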
