Commit 8d45a95

runfuture authored and hodlen committed

llama : add MiniCPM support (ggml-org#5346)

* support minicpm arch.
* fix tab/space typo.
* convert minicpm model via convert-hf-to-gguf.py
* try to make tokenizer work
* fix bug for quantize minicpm
* fix for flake8 lint
* remove convert-minicpm.py
* fix for editorconfig
* correct minicpm model type (size)
* constants expanded for minicpm
* Minor change of the constant names for minicpm

1 parent 7808c1b commit 8d45a95
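The commit message sketches the intended workflow: teach convert-hf-to-gguf.py about the MiniCPMForCausalLM architecture, get the tokenizer exported, then quantize the converted file. Below is a rough usage sketch of that flow; the model path, output names, and the location of the quantize binary are placeholders that depend on the local checkout and build, not on anything shown in this commit.

    import subprocess

    model_dir = "models/MiniCPM-2B"            # placeholder: local HF checkpoint directory
    f16_gguf  = "models/minicpm-2b-f16.gguf"   # placeholder output name

    # convert the HF checkpoint to an F16 GGUF with the converter this commit extends
    subprocess.run(["python", "convert-hf-to-gguf.py", model_dir,
                    "--outfile", f16_gguf, "--outtype", "f16"], check=True)

    # quantize with llama.cpp's quantize tool (binary name/location depends on the build)
    subprocess.run(["./quantize", f16_gguf, "models/minicpm-2b-q4_k_m.gguf", "q4_k_m"],
                   check=True)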

File tree: 3 files changed, +259 −1 lines changed

convert-hf-to-gguf.py
gguf-py/gguf/constants.py
llama.cpp

convert-hf-to-gguf.py

Lines changed: 49 additions & 0 deletions
@@ -22,6 +22,8 @@
 sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
 import gguf
 
+from convert import HfVocab
+
 
 # check for any of the given keys in the dictionary and return the value of the first key found
 def get_key_opts(d, keys):
@@ -205,6 +207,8 @@ def from_model_architecture(model_architecture):
             return OrionModel
         if model_architecture == "InternLM2ForCausalLM":
             return InternLM2Model
+        if model_architecture == "MiniCPMForCausalLM":
+            return MiniCPMModel
         return Model
 
     def _is_model_safetensors(self) -> bool:
@@ -258,6 +262,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
             return gguf.MODEL_ARCH.ORION
         if arch == "InternLM2ForCausalLM":
             return gguf.MODEL_ARCH.INTERNLM2
+        if arch == "MiniCPMForCausalLM":
+            return gguf.MODEL_ARCH.MINICPM
 
         raise NotImplementedError(f'Architecture "{arch}" not supported!')

@@ -402,6 +408,31 @@ def _set_vocab_sentencepiece(self):
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)
 
+    def _set_vocab_hf(self):
+        path = self.dir_model
+        added_tokens_path = self.dir_model
+        vocab = HfVocab(
+            path, added_tokens_path if added_tokens_path.exists() else None
+        )
+        tokens = []
+        scores = []
+        toktypes = []
+
+        for text, score, toktype in vocab.all_tokens():
+            tokens.append(text)
+            scores.append(score)
+            toktypes.append(toktype)
+
+        assert len(tokens) == vocab.vocab_size
+
+        self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
 
 class GPTNeoXModel(Model):
     def set_gguf_parameters(self):
@@ -1041,6 +1072,24 @@ def set_vocab(self):
         self._set_vocab_sentencepiece()
 
 
+class MiniCPMModel(Model):
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        self.gguf_writer.add_name("MiniCPM")
+        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
+        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
+
+    def set_vocab(self):
+        self._set_vocab_hf()
+
+
 class QwenModel(Model):
     @staticmethod
     def token_bytes_to_string(b):
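For readers without the converter's Model plumbing at hand, the following is a minimal, metadata-only sketch that roughly mirrors what MiniCPMModel.set_gguf_parameters() writes, driving gguf.GGUFWriter directly. The hyperparameter values are placeholders standing in for the checkpoint's config.json; they are not taken from this commit.

    import gguf

    # placeholder hyperparameters; the real values come from the model's config.json
    hparams = {
        "num_hidden_layers":       40,
        "max_position_embeddings": 4096,
        "hidden_size":             2304,
        "intermediate_size":       5760,
        "num_attention_heads":     36,
        "num_key_value_heads":     36,
        "rms_norm_eps":            1e-5,
    }

    writer = gguf.GGUFWriter("minicpm-metadata-only.gguf", "minicpm")
    writer.add_name("MiniCPM")
    writer.add_context_length(hparams["max_position_embeddings"])
    writer.add_embedding_length(hparams["hidden_size"])
    writer.add_feed_forward_length(hparams["intermediate_size"])
    writer.add_block_count(hparams["num_hidden_layers"])
    writer.add_head_count(hparams["num_attention_heads"])
    writer.add_head_count_kv(hparams["num_key_value_heads"])
    writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
    writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"])
    writer.add_file_type(1)   # 1 == F16 in the converter's --outtype convention

    # no tensors are added here, so the resulting file only demonstrates the metadata
    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.close()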

gguf-py/gguf/constants.py

Lines changed: 21 additions & 0 deletions
@@ -104,6 +104,7 @@ class MODEL_ARCH(IntEnum):
     CODESHELL = auto()
     ORION = auto()
     INTERNLM2 = auto()
+    MINICPM = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -156,6 +157,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.CODESHELL: "codeshell",
     MODEL_ARCH.ORION: "orion",
     MODEL_ARCH.INTERNLM2: "internlm2",
+    MODEL_ARCH.MINICPM: "minicpm",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -464,6 +466,25 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.MINICPM: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+    ],
     # TODO
 }
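Assuming the per-architecture table extended above is gguf-py's MODEL_TENSORS dict (the dict's name is not visible in this hunk), the new MINICPM entry can be dumped together with the package's TENSOR_NAMES format strings to see which GGUF tensor base names the converter may now emit. A small sketch, run against a gguf-py that already contains this change:

    import gguf

    # list every tensor kind registered for MiniCPM together with its GGUF base-name
    # template (per-layer templates carry a block-index placeholder)
    for tensor in gguf.MODEL_TENSORS[gguf.MODEL_ARCH.MINICPM]:
        print(f"{tensor.name:20s} -> {gguf.TENSOR_NAMES[tensor]}")

These base names line up with the LLM_TENSOR_NAMES entries added to llama.cpp further down in this commit.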

llama.cpp

Lines changed: 189 additions & 1 deletion
@@ -205,6 +205,7 @@ enum llm_arch {
     LLM_ARCH_CODESHELL,
     LLM_ARCH_ORION,
     LLM_ARCH_INTERNLM2,
+    LLM_ARCH_MINICPM,
     LLM_ARCH_UNKNOWN,
 };

@@ -228,6 +229,7 @@ static std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_CODESHELL, "codeshell" },
     { LLM_ARCH_ORION, "orion" },
     { LLM_ARCH_INTERNLM2, "internlm2" },
+    { LLM_ARCH_MINICPM, "minicpm" },
 };
 
 enum llm_kv {
@@ -690,6 +692,29 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_MINICPM,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
+            { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
+            { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -1390,6 +1415,7 @@ enum e_model {
     MODEL_UNKNOWN,
     MODEL_0_5B,
     MODEL_1B,
+    MODEL_2B,
     MODEL_3B,
     MODEL_4B,
     MODEL_7B,
@@ -2748,6 +2774,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
 static const char * llama_model_type_name(e_model type) {
     switch (type) {
         case MODEL_1B: return "1B";
+        case MODEL_2B: return "2B";
         case MODEL_3B: return "3B";
         case MODEL_7B: return "7B";
         case MODEL_8B: return "8B";
@@ -2887,6 +2914,13 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_MINICPM:
+            {
+                switch (hparams.n_layer) {
+                    case 40: model.type = e_model::MODEL_2B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_FALCON:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -3524,13 +3558,16 @@ static bool llm_load_tensors(
     switch (model.arch) {
         case LLM_ARCH_LLAMA:
         case LLM_ARCH_REFACT:
+        case LLM_ARCH_MINICPM:
             {
                 model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
 
                 // output
                 {
                     model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                    model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
+                    if (model.arch != LLM_ARCH_MINICPM){
+                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
+                    }
                 }
 
                 for (int i = 0; i < n_layer; ++i) {
@@ -6781,6 +6818,153 @@
         return gf;
     }
 
+    // ref: https://arxiv.org/abs/2203.03466
+    //      https://github.com/ggerganov/llama.cpp/issues/5276#issuecomment-1925774738
+    // based on the original build_llama() function
+    struct ggml_cgraph * build_minicpm() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        const int64_t n_embd = hparams.n_embd;
+        //TODO: if the model varies, these parameters need to be read from the model
+        const int64_t n_embd_base = 256;
+        const float scale_embd = 12.0f;
+        const float scale_depth = 1.4f;
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
+        cb(inpL, "inp_embd", -1);
+
+        // scale the input embeddings
+        inpL = ggml_scale(ctx0, inpL, scale_embd);
+        cb(inpL, "inp_scaled", -1);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
+        cb(inp_pos, "inp_pos", -1);
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        cb(KQ_mask, "KQ_mask", -1);
+
+        // shift the entire K-cache if needed
+        if (do_rope_shift) {
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
+        }
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                    hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                    hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                cb(cur, "kqv_out", il);
+            }
+
+            // scale_res - scale the hidden states for residual connection
+            const float scale_res = scale_depth/sqrtf(float(n_layer));
+            cur = ggml_scale(ctx0, cur, scale_res);
+            cb(cur, "hidden_scaled", -1);
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up, NULL,
+                        model.layers[il].ffn_gate, NULL,
+                        model.layers[il].ffn_down, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            // scale the hidden states for residual connection
+            cur = ggml_scale(ctx0, cur, scale_res);
+            cb(cur, "hidden_scaled_ffn", -1);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head scaling
+        const float scale_lmhead = float(n_embd_base)/float(n_embd);
+        cur = ggml_scale(ctx0, cur, scale_lmhead);
+        cb(cur, "lmhead_scaling", -1);
+
+        // lm_head
+        cur = ggml_mul_mat(ctx0, model.tok_embd, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 };
 
 static struct ggml_cgraph * llama_build_graph(
@@ -6943,6 +7127,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_internlm2();
             } break;
+        case LLM_ARCH_MINICPM:
+            {
+                result = llm.build_minicpm();
+            } break;
         default:
             GGML_ASSERT(false);
     }
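build_minicpm() differs from build_llama() mainly in three fixed scalings, following the muP-style recipe referenced above (arXiv:2203.03466): the token embeddings are multiplied by scale_embd, each attention and FFN branch is multiplied by scale_depth/sqrt(n_layer) before its residual add, and the hidden state is multiplied by n_embd_base/n_embd before the LM head. A small worked sketch of those numbers; n_layer = 40 matches the MODEL_2B case added in llm_load_hparams, while the hidden size is only an assumed placeholder (at runtime llama.cpp reads the real value from the GGUF header).

    import math

    n_layer     = 40      # layer count this commit maps to MODEL_2B
    n_embd      = 2304    # assumed placeholder hidden size, not taken from this diff
    n_embd_base = 256     # hard-coded in build_minicpm()
    scale_embd  = 12.0    # applied once to the input embeddings
    scale_depth = 1.4

    scale_res    = scale_depth / math.sqrt(n_layer)   # per-branch residual scale
    scale_lmhead = n_embd_base / n_embd               # pre-LM-head scale

    print(f"scale_res    = {scale_res:.4f}")    # ~0.2214
    print(f"scale_lmhead = {scale_lmhead:.4f}") # ~0.1111

Note that the LM head matmul reuses model.tok_embd (tied embeddings), which is why the llm_load_tensors() change above skips creating model.output for LLM_ARCH_MINICPM.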
