
Commit 0d2ec43

llama : support IBM Granite architecture (#9412)
* feat(gguf-py): Add Granite model and params to gguf-py
* feat(convert_hf_to_gguf): Add registration and param setup for Granite
* feat(llama.cpp): Add config parsing for Granite multiplier params
* feat(llama.cpp): First pass at full port of granite deviations from llama
  Something is still not working right since the results are mostly terrible, but on occasion it's producing relevant results at this point, so _something_ is working.
* fix(llama.cpp): Determine granite language 3b instruct by vocab size
* fix(convert_hf_to_gguf): Use LlamaModel as base for GraniteModel
  The defaults in LlamaModel are needed for Granite as well.
* fix(llama.cpp): Switch Granite param names to use _scale for consistency
  Other scalar multipliers are called *_scale, so this provides a more consistent naming convention.
* fix(convert_hf_to_gguf/gguf-py): _multiplier -> _scale
  The transformers names with _multiplier will now be converted to the _scale equivalent during conversion.
* fix(llama.cpp): Use separate switch clause for granite in llm_load_hparams

Branch: GraniteLM
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 37f3a38 commit 0d2ec43
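The commit message above lists four Granite-specific scale parameters. For orientation, here is a minimal numpy sketch (not the actual implementation; shapes, values, and the stand-in attn_block/ffn_block helpers are made up) of where those scales enter the otherwise llama-style forward pass, matching the src/llama.cpp hunks further down: embeddings are multiplied by embedding_scale, the KQ product uses attention_scale instead of 1/sqrt(n_embd_head), each attention/FFN branch output is multiplied by residual_scale before the residual add, and the final logits are divided by logit_scale.

import numpy as np

# Placeholder hyperparameters; the real values come from the GGUF metadata keys
# granite.embedding_scale, granite.residual_scale, granite.attention.scale and
# granite.logit_scale added by this commit.
f_embedding_scale = 12.0
f_residual_scale  = 0.22
f_attention_scale = 0.0078125
f_logit_scale     = 8.0

n_embd, n_vocab, n_tokens = 8, 16, 4
rng      = np.random.default_rng(0)
tok_embd = rng.standard_normal((n_vocab, n_embd))
output_w = rng.standard_normal((n_vocab, n_embd))
tokens   = rng.integers(0, n_vocab, n_tokens)

def attn_block(x):  # stand-in for the real attention block; internally the
    return x        # KQ product would be scaled by f_attention_scale instead
                    # of 1.0/sqrt(n_embd_head)

def ffn_block(x):   # stand-in for the real feed-forward block
    return x

cur = tok_embd[tokens] * f_embedding_scale      # llm_build_inp_embd scaling
cur = cur + f_residual_scale * attn_block(cur)  # branch scaled before residual add
cur = cur + f_residual_scale * ffn_block(cur)   # same for the FFN branch
logits = (cur @ output_w.T) / f_logit_scale     # lm_head output downscaled
print(logits.shape)                             # (4, 16)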


4 files changed: +135, -1 lines changed


convert_hf_to_gguf.py

Lines changed: 30 additions & 0 deletions
@@ -4080,6 +4080,36 @@ def prepare_tensors(self):
         super().prepare_tensors()


+@Model.register("GraniteForCausalLM")
+class GraniteModel(LlamaModel):
+    """Conversion for IBM's GraniteForCausalLM"""
+    model_arch = gguf.MODEL_ARCH.GRANITE
+
+    def set_gguf_parameters(self):
+        """Granite uses standard llama parameters with the following differences:
+
+        - No head_dim support
+        - New multiplier params:
+          - attention_scale
+          - embedding_scale
+          - residual_scale
+          - logits_scaling
+        """
+        if head_dim := self.hparams.pop("head_dim", None):
+            logger.warning("Ignoring head_dim (%s) from config for Granite", head_dim)
+        super().set_gguf_parameters()
+        # NOTE: Convert _multiplier params to _scale params for naming
+        #   consistency
+        if attention_scale := self.hparams.get("attention_multiplier"):
+            self.gguf_writer.add_attention_scale(attention_scale)
+        if embedding_scale := self.hparams.get("embedding_multiplier"):
+            self.gguf_writer.add_embedding_scale(embedding_scale)
+        if residual_scale := self.hparams.get("residual_multiplier"):
+            self.gguf_writer.add_residual_scale(residual_scale)
+        if logits_scaling := self.hparams.get("logits_scaling"):
+            self.gguf_writer.add_logit_scale(logits_scaling)
+
+
 ###### CONVERSION LOGIC ######

 # tree of lazy tensors
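A small illustration (not part of the converter) of the renaming this hunk performs: the Hugging Face *_multiplier config names are written out under the GGUF-side *_scale keys, with logits_scaling mapped onto the existing logit_scale key. The config values below are placeholders.

# Hypothetical Granite config snippet; the values are placeholders.
hf_config = {
    "attention_multiplier": 0.0078125,
    "embedding_multiplier": 12.0,
    "residual_multiplier": 0.22,
    "logits_scaling": 8.0,
}

# HF config name -> GGUFWriter method called by GraniteModel.set_gguf_parameters()
RENAMES = {
    "attention_multiplier": "add_attention_scale",
    "embedding_multiplier": "add_embedding_scale",
    "residual_multiplier": "add_residual_scale",
    "logits_scaling": "add_logit_scale",
}

for hf_name, writer_method in RENAMES.items():
    if (value := hf_config.get(hf_name)) is not None:
        print(f"{hf_name}={value} -> gguf_writer.{writer_method}({value})")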

gguf-py/gguf/constants.py

Lines changed: 18 additions & 0 deletions
@@ -97,6 +97,8 @@ class LLM:
         RESCALE_EVERY_N_LAYERS = "{arch}.rescale_every_n_layers"
         TIME_MIX_EXTRA_DIM = "{arch}.time_mix_extra_dim"
         TIME_DECAY_EXTRA_DIM = "{arch}.time_decay_extra_dim"
+        RESIDUAL_SCALE = "{arch}.residual_scale"
+        EMBEDDING_SCALE = "{arch}.embedding_scale"

     class Attention:
         HEAD_COUNT = "{arch}.attention.head_count"
@@ -112,6 +114,7 @@ class Attention:
         KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
         REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
         SLIDING_WINDOW = "{arch}.attention.sliding_window"
+        SCALE = "{arch}.attention.scale"

     class Rope:
         DIMENSION_COUNT = "{arch}.rope.dimension_count"
@@ -231,6 +234,7 @@ class MODEL_ARCH(IntEnum):
     JAIS = auto()
     NEMOTRON = auto()
     EXAONE = auto()
+    GRANITE = auto()


 class MODEL_TENSOR(IntEnum):
@@ -387,6 +391,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.JAIS: "jais",
     MODEL_ARCH.NEMOTRON: "nemotron",
     MODEL_ARCH.EXAONE: "exaone",
+    MODEL_ARCH.GRANITE: "granite",
 }

 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -1224,6 +1229,19 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.GRANITE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }
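As a quick sanity check (not in the commit), the new key templates combined with the "granite" architecture name registered above expand to the metadata keys that src/llama.cpp reads back:

# Expand the new gguf-py key templates for the "granite" architecture name.
for template in ("{arch}.residual_scale",
                 "{arch}.embedding_scale",
                 "{arch}.attention.scale"):
    print(template.format(arch="granite"))
# granite.residual_scale
# granite.embedding_scale
# granite.attention.scale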

gguf-py/gguf/gguf_writer.py

Lines changed: 9 additions & 0 deletions
@@ -679,6 +679,12 @@ def add_time_mix_extra_dim(self, dim: int) -> None:
     def add_time_decay_extra_dim(self, dim: int) -> None:
         self.add_uint32(Keys.LLM.TIME_DECAY_EXTRA_DIM.format(arch=self.arch), dim)

+    def add_residual_scale(self, value: float) -> None:
+        self.add_float32(Keys.LLM.RESIDUAL_SCALE.format(arch=self.arch), value)
+
+    def add_embedding_scale(self, value: float) -> None:
+        self.add_float32(Keys.LLM.EMBEDDING_SCALE.format(arch=self.arch), value)
+
     def add_wkv_head_size(self, size: int) -> None:
         self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size)

@@ -703,6 +709,9 @@ def add_relative_attn_buckets_count(self, value: int) -> None:
     def add_sliding_window(self, value: int) -> None:
         self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value)

+    def add_attention_scale(self, value: float) -> None:
+        self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value)
+
     def add_pooling_type(self, value: PoolingType) -> None:
         self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
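A standalone sketch of the new writer methods, assuming the public gguf-py GGUFWriter constructor; the output path and scale values are placeholders, and a real conversion goes through convert_hf_to_gguf.py rather than calling the writer directly.

import gguf

# Placeholder output path; nothing is written until the write_* calls mentioned below.
arch = gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.GRANITE]  # "granite"
writer = gguf.GGUFWriter("granite-example.gguf", arch)

writer.add_embedding_scale(12.0)       # -> granite.embedding_scale (float32)
writer.add_residual_scale(0.22)        # -> granite.residual_scale  (float32)
writer.add_attention_scale(0.0078125)  # -> granite.attention.scale (float32)
writer.add_logit_scale(8.0)            # existing key: granite.logit_scale

# A real converter would then add tensors and call write_header_to_file(),
# write_kv_data_to_file(), write_tensors_to_file() and close().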

src/llama.cpp

Lines changed: 78 additions & 1 deletion
@@ -214,6 +214,7 @@ enum llm_arch {
     LLM_ARCH_NEMOTRON,
     LLM_ARCH_EXAONE,
     LLM_ARCH_RWKV6,
+    LLM_ARCH_GRANITE,
     LLM_ARCH_UNKNOWN,
 };

@@ -264,6 +265,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_NEMOTRON, "nemotron" },
     { LLM_ARCH_EXAONE, "exaone" },
     { LLM_ARCH_RWKV6, "rwkv6" },
+    { LLM_ARCH_GRANITE, "granite" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };

@@ -303,6 +305,8 @@ enum llm_kv {
     LLM_KV_RESCALE_EVERY_N_LAYERS,
     LLM_KV_TIME_MIX_EXTRA_DIM,
     LLM_KV_TIME_DECAY_EXTRA_DIM,
+    LLM_KV_RESIDUAL_SCALE,
+    LLM_KV_EMBEDDING_SCALE,

     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -317,6 +321,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_KV_LORA_RANK,
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
     LLM_KV_ATTENTION_SLIDING_WINDOW,
+    LLM_KV_ATTENTION_SCALE,

     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_FREQ_BASE,
@@ -407,6 +412,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_RESCALE_EVERY_N_LAYERS, "%s.rescale_every_n_layers" },
     { LLM_KV_TIME_MIX_EXTRA_DIM, "%s.time_mix_extra_dim" },
     { LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" },
+    { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
+    { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },

     { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -421,6 +428,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
     { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
     { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
+    { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },

     { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
     { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
@@ -1454,6 +1462,22 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "blk.%d.channel_mix_receptance" },
         },
     },
+    {
+        LLM_ARCH_GRANITE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -2372,6 +2396,11 @@ struct llama_hparams {
     float f_max_alibi_bias = 0.0f;
     float f_logit_scale = 0.0f;

+    // Additional scale factors (Granite)
+    float f_residual_scale = 0.0f;
+    float f_embedding_scale = 0.0f;
+    float f_attention_scale = 0.0f;
+
     bool causal_attn = true;
     bool use_alibi = false;
     bool attn_soft_cap = false;
@@ -2434,6 +2463,9 @@ struct llama_hparams {
         if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true;
         if (!is_float_close(this->expert_weights_scale, other.expert_weights_scale, EPSILON)) return true;
         if (!is_float_close(this->rope_yarn_log_mul, other.rope_yarn_log_mul, EPSILON)) return true;
+        if (!is_float_close(this->f_residual_scale, other.f_residual_scale, EPSILON)) return true;
+        if (!is_float_close(this->f_embedding_scale, other.f_embedding_scale, EPSILON)) return true;
+        if (!is_float_close(this->f_attention_scale, other.f_attention_scale, EPSILON)) return true;

         return false;
     }
@@ -6019,6 +6051,20 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_GRANITE:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
+                ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale);
+                ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale);
+                ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale);
+
+                switch (hparams.n_layer) {
+                    case 40: model.type = e_model::MODEL_3B; break;
+                    // Add additional layer/vocab/etc checks here for other model sizes
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }
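As a hedged aside, the same keys that llm_load_hparams fetches here can be inspected from Python with gguf-py's reader; the file name is a placeholder for a converted Granite GGUF, and the parts/data indexing is an assumption about the reader's field layout.

from gguf import GGUFReader

reader = GGUFReader("granite-example.gguf")  # placeholder path
for key in ("granite.logit_scale",
            "granite.residual_scale",
            "granite.embedding_scale",
            "granite.attention.scale"):
    field = reader.get_field(key)
    if field is not None:
        # Assumption: scalar fields keep their value in the part indexed by data[0].
        print(key, float(field.parts[field.data[0]][0]))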

@@ -6717,6 +6763,12 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
         LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
     }
+
+    if (model.arch == LLM_ARCH_GRANITE) {
+        LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
+        LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale);
+        LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
+    }
 }

 // Returns false if cancelled by progress_callback
@@ -6885,6 +6937,7 @@ static bool llm_load_tensors(
         case LLM_ARCH_LLAMA:
         case LLM_ARCH_REFACT:
         case LLM_ARCH_MINICPM:
+        case LLM_ARCH_GRANITE:
             {
                 model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});

@@ -8868,6 +8921,11 @@ static struct ggml_tensor * llm_build_inp_embd(
         ggml_set_input(lctx.inp_embd);
     }

+    // For Granite architecture
+    if (hparams.f_embedding_scale != 0.0f) {
+        inpL = ggml_scale(ctx, inpL, hparams.f_embedding_scale);
+    }
+
     cb(inpL, "inp_embd", -1);

     return inpL;
@@ -10146,6 +10204,7 @@ struct llm_build_context {
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();

+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
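The kq_scale line above preserves the default llama behaviour: f_attention_scale defaults to 0.0f in llama_hparams, in which case attention falls back to 1/sqrt(n_embd_head). A tiny sketch of that selection with made-up numbers:

import math

def kq_scale(f_attention_scale: float, n_embd_head: int) -> float:
    # Mirrors: f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : f_attention_scale
    return 1.0 / math.sqrt(n_embd_head) if f_attention_scale == 0.0 else f_attention_scale

print(kq_scale(0.0, 128))        # llama default: ~0.0884
print(kq_scale(0.0078125, 128))  # Granite: the stored attention scale wins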

@@ -10198,7 +10257,7 @@ struct llm_build_context {

             cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                     model.layers[il].wo, model.layers[il].bo,
-                    Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                    Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
         }

         if (il == n_layer - 1) {
@@ -10209,6 +10268,11 @@ struct llm_build_context {
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
             }

+            // For Granite architecture
+            if (hparams.f_residual_scale) {
+                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            }
+
             struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
             cb(ffn_inp, "ffn_inp", il);

@@ -10245,6 +10309,11 @@ struct llm_build_context {
                 cb(cur, "ffn_moe_out", il);
             }

+            // For Granite architecture
+            if (hparams.f_residual_scale) {
+                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            }
+
             cur = ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);

@@ -10264,6 +10333,12 @@ struct llm_build_context {

         // lm_head
         cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        // For Granite architecture
+        if (hparams.f_logit_scale) {
+            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
+        }
+
         cb(cur, "result_output", -1);

         ggml_build_forward_expand(gf, cur);
@@ -15789,6 +15864,7 @@ static struct ggml_cgraph * llama_build_graph(

     switch (model.arch) {
         case LLM_ARCH_LLAMA:
+        case LLM_ARCH_GRANITE:
             {
                 result = llm.build_llama();
             } break;
@@ -19089,6 +19165,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_ARCTIC:
         case LLM_ARCH_DEEPSEEK2:
         case LLM_ARCH_CHATGLM:
+        case LLM_ARCH_GRANITE:
             return LLAMA_ROPE_TYPE_NORM;

         // the pairs of head values are offset by n_rot/2
