
Commit f125b8d

llama : add PLM GGUF Conversion & Inference Support (#12457)
* add edgellm model arch (conversation feature doesn't work)
* remove output.weight layer for edgellm arch
* [Model] update the name of the model
* update the name of the model arch in convert gguf
* [Model] Refactor the model arch into llama-model
* [Bug] Fix the bug in create attn kv
* [Code] Fix editorconfig errors
* [Code] Remove trailing whitespace
* [Code] Remove trailing whitespace
* [Code] Change the order of model arch in list
* [Code] Fix flake8 lint errors
* Remove trailing whitespace
* [Code] Remove call in model arch
1 parent 953c2a6 commit f125b8d

File tree: 6 files changed (+274, −0 lines)

convert_hf_to_gguf.py

Lines changed: 23 additions & 0 deletions
@@ -4419,6 +4419,29 @@ def prepare_tensors(self):
             raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@Model.register("PLMForCausalLM")
+class PLMModel(Model):
+    model_arch = gguf.MODEL_ARCH.PLM
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+        self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])
+        self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
+        self.gguf_writer.add_value_length(hparams["v_head_dim"])
+        self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+
 @Model.register("T5WithLMHeadModel")
 @Model.register("T5ForConditionalGeneration")
 @Model.register("MT5ForConditionalGeneration")

gguf-py/gguf/constants.py

Lines changed: 16 additions & 0 deletions
@@ -286,6 +286,7 @@ class MODEL_ARCH(IntEnum):
     GRANITE_MOE      = auto()
     CHAMELEON        = auto()
     WAVTOKENIZER_DEC = auto()
+    PLM              = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -488,6 +489,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.GRANITE_MOE:      "granitemoe",
     MODEL_ARCH.CHAMELEON:        "chameleon",
     MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
+    MODEL_ARCH.PLM:              "plm",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -1464,6 +1466,20 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_UP_SHEXP,
         MODEL_TENSOR.FFN_EXP_PROBS_B,
     ],
+    MODEL_ARCH.PLM: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_KV_A_MQA,
+        MODEL_TENSOR.ATTN_KV_A_NORM,
+        MODEL_TENSOR.ATTN_KV_B,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_DOWN,
+    ],
     MODEL_ARCH.CHATGLM : [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.ROPE_FREQS,
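Note: once MODEL_ARCH.PLM is registered in these tables, the rest of gguf-py can enumerate the tensors a PLM GGUF file is expected to carry. A small sketch, assuming the in-tree gguf-py package (with this change applied) is importable; it relies on the {bid}-style name templates used in TENSOR_NAMES:

import gguf

# Print the GGUF tensor name templates registered for PLM,
# expanding per-block names for block 0 as an example.
for tensor in gguf.MODEL_TENSORS[gguf.MODEL_ARCH.PLM]:
    template = gguf.TENSOR_NAMES[tensor]
    print(template.format(bid=0) if "{bid}" in template else template)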

src/llama-arch.cpp

Lines changed: 17 additions & 0 deletions
@@ -65,6 +65,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_GRANITE_MOE,      "granitemoe"       },
     { LLM_ARCH_CHAMELEON,        "chameleon"        },
     { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
+    { LLM_ARCH_PLM,              "plm"              },
     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
 };
 
@@ -1043,6 +1044,22 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
         },
     },
+    {
+        LLM_ARCH_PLM,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,     "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,    "output_norm" },
+            { LLM_TENSOR_ATTN_NORM,      "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,         "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_KV_A_MQA,  "blk.%d.attn_kv_a_mqa" },
+            { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" },
+            { LLM_TENSOR_ATTN_KV_B,      "blk.%d.attn_kv_b" },
+            { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,       "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_DOWN,       "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,         "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_CHATGLM,
         {

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ enum llm_arch {
     LLM_ARCH_GRANITE_MOE,
     LLM_ARCH_CHAMELEON,
     LLM_ARCH_WAVTOKENIZER_DEC,
+    LLM_ARCH_PLM,
     LLM_ARCH_UNKNOWN,
 };

src/llama-model.cpp

Lines changed: 216 additions & 0 deletions
@@ -47,6 +47,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_1_4B: return "1.4B";
         case LLM_TYPE_1_5B: return "1.5B";
         case LLM_TYPE_1_6B: return "1.6B";
+        case LLM_TYPE_1_8B: return "1.8B";
         case LLM_TYPE_2B: return "2B";
         case LLM_TYPE_2_8B: return "2.8B";
        case LLM_TYPE_2_9B: return "2.9B";
@@ -1144,6 +1145,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_PLM:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
+                switch (hparams.n_layer) {
+                    case 32: type = LLM_TYPE_1_8B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_CHATGLM:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -3068,6 +3078,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         }
                     }
                 } break;
+            case LLM_ARCH_PLM:
+                {
+                    const int64_t n_embd_head_qk_rope = hparams.n_rot;
+                    const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+                    const int64_t kv_lora_rank = hparams.n_lora_kv;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    // output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, 0);
+                        layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0);
+                        layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                    }
+                } break;
             case LLM_ARCH_BITNET:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
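Note: the create_tensor() calls in the hunk above fix the per-layer tensor shapes (in ggml's {ne0, ne1} order). A quick Python sketch of the resulting shapes, using hypothetical dimensions chosen only to make the arithmetic concrete:

# Hypothetical PLM-like dimensions (illustrative only)
n_embd, n_head, n_ff = 2048, 16, 8192
qk_nope_head_dim     = 128
qk_rope_head_dim     = 64
v_head_dim           = 128
kv_lora_rank         = 512
n_embd_head_k = qk_nope_head_dim + qk_rope_head_dim   # 192

# Per-layer shapes mirroring load_tensors() for LLM_ARCH_PLM, ggml order {ne0, ne1}
shapes = {
    "attn_q.weight":         (n_embd, n_embd_head_k * n_head),
    "attn_kv_a_mqa.weight":  (n_embd, kv_lora_rank + qk_rope_head_dim),
    "attn_kv_a_norm.weight": (kv_lora_rank,),
    "attn_kv_b.weight":      (kv_lora_rank, n_head * (qk_nope_head_dim + v_head_dim)),
    "attn_output.weight":    (n_head * v_head_dim, n_embd),
    "ffn_down.weight":       (n_ff, n_embd),
    "ffn_up.weight":         (n_embd, n_ff),
}
for name, shape in shapes.items():
    print(f"blk.0.{name}: {shape}")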
@@ -11615,6 +11654,178 @@ struct llm_build_wavtokenizer_dec : public llm_graph_context {
     }
 };
 
+struct llm_build_plm : public llm_graph_context {
+    llm_build_plm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const float kq_scale = 1.0f/sqrtf(float(hparams.n_embd_head_k));
+
+        const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+        const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+        const uint32_t kv_lora_rank = hparams.n_lora_kv;
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        // {n_embd, n_tokens}
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self_attention
+            {
+                ggml_tensor * q = NULL;
+                q = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                cb(q, "q", il);
+
+                // split into {n_head * n_embd_head_qk_nope, n_tokens}
+                ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
+                        ggml_row_size(q->type, hparams.n_embd_head_k),
+                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        0);
+                cb(q_nope, "q_nope", il);
+
+                // and {n_head * n_embd_head_qk_rope, n_tokens}
+                ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
+                        ggml_row_size(q->type, hparams.n_embd_head_k),
+                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        ggml_row_size(q->type, n_embd_head_qk_nope));
+                cb(q_pe, "q_pe", il);
+
+                // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
+                ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
+                cb(kv_pe_compresseed, "kv_pe_compresseed", il);
+
+                // split into {kv_lora_rank, n_tokens}
+                ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens,
+                        kv_pe_compresseed->nb[1],
+                        0);
+                cb(kv_compressed, "kv_compressed", il);
+
+                // and {n_embd_head_qk_rope, n_tokens}
+                ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
+                        kv_pe_compresseed->nb[1],
+                        kv_pe_compresseed->nb[1],
+                        ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
+                cb(k_pe, "k_pe", il);
+
+                kv_compressed = build_norm(kv_compressed,
+                        model.layers[il].attn_kv_a_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(kv_compressed, "kv_compressed", il);
+
+                // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
+                ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
+                cb(kv, "kv", il);
+
+                // split into {n_head * n_embd_head_qk_nope, n_tokens}
+                ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+                        ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
+                        ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        0);
+                cb(k_nope, "k_nope", il);
+
+                // and {n_head * n_embd_head_v, n_tokens}
+                ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope)));
+                cb(v_states, "v_states", il);
+
+                v_states = ggml_cont(ctx0, v_states);
+                cb(v_states, "v_states", il);
+
+                v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
+                        ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
+                        0);
+                cb(v_states, "v_states", il);
+
+                q_pe = ggml_rope_ext(
+                        ctx0, q_pe, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(q_pe, "q_pe", il);
+
+                // shared RoPE key
+                k_pe = ggml_rope_ext(
+                        ctx0, k_pe, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(k_pe, "k_pe", il);
+
+                ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
+                cb(q_states, "q_states", il);
+
+                ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
+                cb(k_states, "k_states", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        q_states, k_states, v_states, nullptr, kq_scale, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up, NULL, NULL,
+                    NULL, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 llama_memory_i * llama_model::create_memory() const {
     llama_memory_i * res;
 
@@ -11887,6 +12098,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_wavtokenizer_dec>(*this, params, gf);
             } break;
+        case LLM_ARCH_PLM:
+            {
+                llm = std::make_unique<llm_build_plm>(*this, params, gf);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -12013,6 +12228,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_ARCTIC:
         case LLM_ARCH_DEEPSEEK:
         case LLM_ARCH_DEEPSEEK2:
+        case LLM_ARCH_PLM:
         case LLM_ARCH_CHATGLM:
         case LLM_ARCH_GRANITE:
         case LLM_ARCH_GRANITE_MOE:
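Note: llm_build_plm follows the DeepSeek-style multi-head latent attention (MLA) layout: queries are projected at full width and split into non-rotary and rotary parts, while K/V are first compressed into a latent of width kv_lora_rank plus a single shared rotary key, then re-expanded per head through wkv_b. A rough NumPy sketch of one layer's attention shape flow (RoPE, RMS norm, causal masking, and the KV cache are omitted; all weights are random and the dimensions are hypothetical):

import numpy as np

# Hypothetical dimensions (illustrative only)
n_tokens, n_embd, n_head = 4, 2048, 16
qk_nope, qk_rope, v_dim, kv_rank = 128, 64, 128, 512
head_k = qk_nope + qk_rope

rng   = np.random.default_rng(0)
x     = rng.standard_normal((n_tokens, n_embd))
wq    = rng.standard_normal((n_embd, n_head * head_k))
wkv_a = rng.standard_normal((n_embd, kv_rank + qk_rope))
wkv_b = rng.standard_normal((kv_rank, n_head * (qk_nope + v_dim)))
wo    = rng.standard_normal((n_head * v_dim, n_embd))

# queries: full-width projection, split per head into nope/rope parts (RoPE would rotate q_pe)
q = (x @ wq).reshape(n_tokens, n_head, head_k)
q_nope, q_pe = q[..., :qk_nope], q[..., qk_nope:]

# keys/values: compress to a small latent plus one shared rotary key (RMS norm of the latent omitted)
kv_a = x @ wkv_a
kv_compressed, k_pe = kv_a[:, :kv_rank], kv_a[:, kv_rank:]

# expand the latent to per-head k_nope and v
kv = (kv_compressed @ wkv_b).reshape(n_tokens, n_head, qk_nope + v_dim)
k_nope, v = kv[..., :qk_nope], kv[..., qk_nope:]

# assemble q/k; the single rotary key is broadcast across heads (ggml_repeat in the graph above)
q_full = np.concatenate([q_nope, q_pe], axis=-1)
k_full = np.concatenate([k_nope, np.broadcast_to(k_pe[:, None, :], (n_tokens, n_head, qk_rope))], axis=-1)

# scaled dot-product attention with kq_scale = 1/sqrt(n_embd_head_k)
scores = np.einsum("qhd,khd->hqk", q_full, k_full) / np.sqrt(head_k)
probs  = np.exp(scores - scores.max(-1, keepdims=True))
probs /= probs.sum(-1, keepdims=True)
out    = np.einsum("hqk,khd->qhd", probs, v).reshape(n_tokens, n_head * v_dim) @ wo
print(out.shape)  # (n_tokens, n_embd)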

src/llama-model.h

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@ enum llm_type {
     LLM_TYPE_1_4B,
     LLM_TYPE_1_5B,
     LLM_TYPE_1_6B,
+    LLM_TYPE_1_8B,
     LLM_TYPE_2B,
     LLM_TYPE_2_8B,
     LLM_TYPE_2_9B,
