Skip to content

Commit 2116f48

Browse files
committed
Add support for the cohere2 model architecture.
1 parent d408bb9 commit 2116f48

File tree

3 files changed

+221
-0
lines changed

3 files changed

+221
-0
lines changed

convert_hf_to_gguf.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3167,6 +3167,24 @@ def set_gguf_parameters(self):
31673167
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
31683168

31693169

3170+
@Model.register("Cohere2ForCausalLM")
3171+
class Cohere2Model(Model):
3172+
model_arch = gguf.MODEL_ARCH.COHERE2
3173+
3174+
def set_gguf_parameters(self):
3175+
super().set_gguf_parameters()
3176+
3177+
self.gguf_writer.add_logit_scale(self.hparams["logit_scale"])
3178+
self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
3179+
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
3180+
3181+
rotary_pct = self.hparams["rotary_pct"]
3182+
hidden_size = self.hparams["hidden_size"]
3183+
num_attention_heads = self.hparams["num_attention_heads"]
3184+
self.gguf_writer.add_rope_dimension_count(int(rotary_pct * (hidden_size // num_attention_heads)))
3185+
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
3186+
3187+
31703188
@Model.register("OlmoForCausalLM")
31713189
@Model.register("OLMoForCausalLM")
31723190
class OlmoModel(Model):

gguf-py/gguf/constants.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ class MODEL_ARCH(IntEnum):
254254
MAMBA = auto()
255255
XVERSE = auto()
256256
COMMAND_R = auto()
257+
COHERE2 = auto()
257258
DBRX = auto()
258259
OLMO = auto()
259260
OLMO2 = auto()
@@ -435,6 +436,7 @@ class MODEL_TENSOR(IntEnum):
435436
MODEL_ARCH.MAMBA: "mamba",
436437
MODEL_ARCH.XVERSE: "xverse",
437438
MODEL_ARCH.COMMAND_R: "command-r",
439+
MODEL_ARCH.COHERE2: "cohere2",
438440
MODEL_ARCH.DBRX: "dbrx",
439441
MODEL_ARCH.OLMO: "olmo",
440442
MODEL_ARCH.OLMO2: "olmo2",
@@ -1114,6 +1116,18 @@ class MODEL_TENSOR(IntEnum):
11141116
MODEL_TENSOR.ATTN_K_NORM,
11151117
MODEL_TENSOR.ATTN_Q_NORM,
11161118
],
1119+
MODEL_ARCH.COHERE2: [
1120+
MODEL_TENSOR.TOKEN_EMBD,
1121+
MODEL_TENSOR.OUTPUT_NORM,
1122+
MODEL_TENSOR.ATTN_NORM,
1123+
MODEL_TENSOR.ATTN_Q,
1124+
MODEL_TENSOR.ATTN_K,
1125+
MODEL_TENSOR.ATTN_V,
1126+
MODEL_TENSOR.ATTN_OUT,
1127+
MODEL_TENSOR.FFN_GATE,
1128+
MODEL_TENSOR.FFN_DOWN,
1129+
MODEL_TENSOR.FFN_UP,
1130+
],
11171131
MODEL_ARCH.DBRX: [
11181132
MODEL_TENSOR.TOKEN_EMBD,
11191133
MODEL_TENSOR.OUTPUT_NORM,

src/llama.cpp

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ enum llm_arch {
178178
LLM_ARCH_MAMBA,
179179
LLM_ARCH_XVERSE,
180180
LLM_ARCH_COMMAND_R,
181+
LLM_ARCH_COHERE2,
181182
LLM_ARCH_DBRX,
182183
LLM_ARCH_OLMO,
183184
LLM_ARCH_OLMO2,
@@ -235,6 +236,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
235236
{ LLM_ARCH_MAMBA, "mamba" },
236237
{ LLM_ARCH_XVERSE, "xverse" },
237238
{ LLM_ARCH_COMMAND_R, "command-r" },
239+
{ LLM_ARCH_COHERE2, "cohere2" },
238240
{ LLM_ARCH_DBRX, "dbrx" },
239241
{ LLM_ARCH_OLMO, "olmo" },
240242
{ LLM_ARCH_OLMO2, "olmo2" },
@@ -1240,6 +1242,21 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
12401242
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
12411243
},
12421244
},
1245+
{
1246+
LLM_ARCH_COHERE2,
1247+
{
1248+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1249+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
1250+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
1251+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
1252+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
1253+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
1254+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
1255+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
1256+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
1257+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
1258+
},
1259+
},
12431260
{
12441261
LLM_ARCH_DBRX,
12451262
{
@@ -6110,6 +6127,16 @@ static void llm_load_hparams(
61106127
default: model.type = e_model::MODEL_UNKNOWN;
61116128
}
61126129
} break;
6130+
case LLM_ARCH_COHERE2:
6131+
{
6132+
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
6133+
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
6134+
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
6135+
switch (hparams.n_layer) {
6136+
case 32: model.type = e_model::MODEL_8B; break;
6137+
default: model.type = e_model::MODEL_UNKNOWN;
6138+
}
6139+
} break;
61136140
case LLM_ARCH_DBRX:
61146141
{
61156142
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -8863,6 +8890,32 @@ static bool llm_load_tensors(
88638890
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
88648891
}
88658892
} break;
8893+
case LLM_ARCH_COHERE2:
8894+
{
8895+
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
8896+
8897+
// output
8898+
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
8899+
// init output from the input tok embed
8900+
model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab },
8901+
llama_model_loader::TENSOR_DUPLICATED);
8902+
8903+
for (int i = 0; i < n_layer; ++i) {
8904+
auto & layer = model.layers[i];
8905+
8906+
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
8907+
8908+
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0);
8909+
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0);
8910+
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0);
8911+
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
8912+
8913+
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
8914+
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
8915+
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
8916+
}
8917+
}
8918+
break;
88668919
case LLM_ARCH_OLMO: // adapted from LLM_ARCH_LLAMA with norm params removed
88678920
{
88688921
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -14783,6 +14836,137 @@ struct llm_build_context {
1478314836

1478414837
}
1478514838

14839+
struct ggml_cgraph * build_cohere2() {
14840+
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
14841+
14842+
const int64_t n_embd_head = hparams.n_embd_head_v;
14843+
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
14844+
const float f_logit_scale = hparams.f_logit_scale;
14845+
14846+
struct ggml_tensor * cur;
14847+
struct ggml_tensor * inpL;
14848+
14849+
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
14850+
14851+
// inp_pos - contains the positions
14852+
struct ggml_tensor * inp_pos = build_inp_pos();
14853+
14854+
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
14855+
// cohere2 requires different mask for layers using sliding window (SWA)
14856+
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
14857+
struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
14858+
14859+
// sliding window switch pattern
14860+
const int32_t sliding_window_pattern = 4;
14861+
14862+
for (int il = 0; il < n_layer; ++il) {
14863+
// three layers sliding window attention (window size 4096) and ROPE
14864+
// fourth layer uses global attention without positional embeddings
14865+
const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
14866+
struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
14867+
14868+
// norm
14869+
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM, cb, il);
14870+
cb(cur, "attn_norm", il);
14871+
struct ggml_tensor * ffn_inp = cur;
14872+
14873+
// self-attention
14874+
{
14875+
// rope freq factors for 128k context
14876+
struct ggml_tensor * rope_factors = build_rope_factors(il);
14877+
14878+
// compute Q and K and RoPE them
14879+
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
14880+
cb(Qcur, "Qcur", il);
14881+
if (model.layers[il].bq) {
14882+
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
14883+
cb(Qcur, "Qcur", il);
14884+
}
14885+
14886+
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
14887+
cb(Kcur, "Kcur", il);
14888+
if (model.layers[il].bk) {
14889+
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
14890+
cb(Kcur, "Kcur", il);
14891+
}
14892+
14893+
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
14894+
cb(Vcur, "Vcur", il);
14895+
if (model.layers[il].bv) {
14896+
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
14897+
cb(Vcur, "Vcur", il);
14898+
}
14899+
14900+
if (is_sliding) {
14901+
Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
14902+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
14903+
beta_fast, beta_slow);
14904+
cb(Qcur, "Qcur", il);
14905+
14906+
Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
14907+
rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
14908+
attn_factor, beta_fast, beta_slow);
14909+
cb(Kcur, "Kcur", il);
14910+
} else {
14911+
// For non-sliding layers, just reshape without applying RoPE
14912+
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
14913+
cb(Qcur, "Qcur", il);
14914+
14915+
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
14916+
cb(Kcur, "Kcur", il);
14917+
}
14918+
14919+
cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur,
14920+
KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il);
14921+
}
14922+
14923+
if (il == n_layer - 1) {
14924+
// skip computing output for unused tokens
14925+
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
14926+
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
14927+
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
14928+
ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
14929+
}
14930+
14931+
struct ggml_tensor * attn_out = cur;
14932+
14933+
// feed-forward network
14934+
{
14935+
cur = llm_build_ffn(ctx0, lctx, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate,
14936+
NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR,
14937+
cb, il);
14938+
cb(cur, "ffn_out", il);
14939+
}
14940+
14941+
// add together residual + FFN + self-attention
14942+
cur = ggml_add(ctx0, cur, inpL);
14943+
cur = ggml_add(ctx0, cur, attn_out);
14944+
cur = lctx.cvec.apply_to(ctx0, cur, il);
14945+
cb(cur, "l_out", il);
14946+
14947+
// input for next layer
14948+
inpL = cur;
14949+
}
14950+
14951+
cur = inpL;
14952+
14953+
cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM, cb, -1);
14954+
cb(cur, "result_norm", -1);
14955+
14956+
// lm_head
14957+
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
14958+
14959+
if (f_logit_scale) {
14960+
cur = ggml_scale(ctx0, cur, f_logit_scale);
14961+
}
14962+
14963+
cb(cur, "result_output", -1);
14964+
14965+
ggml_build_forward_expand(gf, cur);
14966+
14967+
return gf;
14968+
}
14969+
1478614970
// ref: https://allenai.org/olmo
1478714971
// based on the original build_llama() function, changes:
1478814972
// * non-parametric layer norm
@@ -17530,6 +17714,10 @@ static struct ggml_cgraph * llama_build_graph(
1753017714
{
1753117715
result = llm.build_command_r();
1753217716
} break;
17717+
case LLM_ARCH_COHERE2:
17718+
{
17719+
result = llm.build_cohere2();
17720+
} break;
1753317721
case LLM_ARCH_DBRX:
1753417722
{
1753517723
result = llm.build_dbrx();
@@ -20802,6 +20990,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
2080220990
case LLM_ARCH_MINICPM:
2080320991
case LLM_ARCH_XVERSE:
2080420992
case LLM_ARCH_COMMAND_R:
20993+
case LLM_ARCH_COHERE2:
2080520994
case LLM_ARCH_OLMO:
2080620995
case LLM_ARCH_ARCTIC:
2080720996
case LLM_ARCH_DEEPSEEK:

0 commit comments

Comments
 (0)