Commit e07d03b

Add Command-R Model
Information about the Command-R model can be found at: https://huggingface.co/CohereForAI/c4ai-command-r-v01

1) Download Command-R Hugging Face safetensors:

   git lfs install
   git clone https://huggingface.co/CohereForAI/c4ai-command-r-v01

2) Convert safetensors to GGUF format:

   python3 convert-hf-to-gguf.py --outtype f16 ./c4ai-command-r-v01
1 parent 306d34b commit e07d03b
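
The two preparation steps from the commit message can also be scripted. A minimal Python sketch (not part of the commit), assuming git, git-lfs, and this repository's convert-hf-to-gguf.py are available in the working directory:

    # Scripted version of the download + convert steps above (sketch only).
    import subprocess

    repo = "https://huggingface.co/CohereForAI/c4ai-command-r-v01"

    # 1) Download the Command-R safetensors checkpoint via git-lfs.
    subprocess.run(["git", "lfs", "install"], check=True)
    subprocess.run(["git", "clone", repo], check=True)

    # 2) Convert the safetensors to an f16 GGUF file.
    subprocess.run(
        ["python3", "convert-hf-to-gguf.py", "--outtype", "f16", "./c4ai-command-r-v01"],
        check=True,
    )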

File tree

4 files changed: +178 -0 lines changed


README.md

Lines changed: 1 addition & 0 deletions
@@ -110,6 +110,7 @@ Typically finetunes of the base models below are supported as well.
 - [x] [CodeShell](https://github.com/WisdomShell/codeshell)
 - [x] [Gemma](https://ai.google.dev/gemma)
 - [x] [Mamba](https://github.com/state-spaces/mamba)
+- [x] [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01)
 
 **Multimodal models:**
 
convert-hf-to-gguf.py

Lines changed: 5 additions & 0 deletions
@@ -1965,6 +1965,11 @@ def write_tensors(self):
             self.gguf_writer.add_tensor(new_name, data)
 
 
+@Model.register("CohereForCausalLM")
+class CommandR2Model(Model):
+    model_arch = gguf.MODEL_ARCH.COMMAND_R
+
+
 ###### CONVERSION LOGIC ######
 
 
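The two-line converter class above works through the registration decorator: @Model.register("CohereForCausalLM") maps the architectures entry from the checkpoint's config.json to CommandR2Model, so the script can pick the right subclass when it reads the HF config. A simplified, hypothetical sketch of that dispatch pattern (the real Model base class in convert-hf-to-gguf.py carries far more machinery, and for_architecture is an illustrative name, not the script's actual lookup helper):

    # Hypothetical sketch of the register/dispatch pattern, not the real
    # convert-hf-to-gguf.py internals.
    class Model:
        _registry: dict[str, type] = {}

        @classmethod
        def register(cls, name: str):
            def wrapper(model_cls: type) -> type:
                cls._registry[name] = model_cls
                return model_cls
            return wrapper

        @classmethod
        def for_architecture(cls, name: str) -> type:
            # `name` would come from the "architectures" list in config.json
            return cls._registry[name]


    @Model.register("CohereForCausalLM")
    class CommandR2Model(Model):
        model_arch = "command-r"  # the real subclass sets gguf.MODEL_ARCH.COMMAND_R


    print(Model.for_architecture("CohereForCausalLM").__name__)  # CommandR2Model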

gguf-py/gguf/constants.py

Lines changed: 14 additions & 0 deletions
@@ -120,6 +120,7 @@ class MODEL_ARCH(IntEnum):
     GEMMA = auto()
     STARCODER2 = auto()
     MAMBA = auto()
+    COMMAND_R = auto()
 
 
 class MODEL_TENSOR(IntEnum):

@@ -186,6 +187,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.GEMMA: "gemma",
     MODEL_ARCH.STARCODER2: "starcoder2",
     MODEL_ARCH.MAMBA: "mamba",
+    MODEL_ARCH.COMMAND_R: "command-r",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {

@@ -578,6 +580,18 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.SSM_D,
         MODEL_TENSOR.SSM_OUT,
     ],
+    MODEL_ARCH.COMMAND_R: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }
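
The effect of these three additions (the new enum member, its "command-r" name mapping, and its tensor list) can be checked from the gguf-py package. A small sketch, assuming gguf-py is installed from this tree and that the dictionaries are exported as MODEL_ARCH_NAMES, MODEL_TENSORS and TENSOR_NAMES (only TENSOR_NAMES is visible in the hunk above; the other two names are assumptions):

    # Inspect the new Command-R registration in gguf-py (dict names assumed).
    import gguf

    arch = gguf.MODEL_ARCH.COMMAND_R
    print(gguf.MODEL_ARCH_NAMES[arch])   # expected: "command-r"

    # Base tensor names the GGUF writer knows about for this architecture.
    for t in gguf.MODEL_TENSORS[arch]:
        print(gguf.TENSOR_NAMES[t])      # e.g. "token_embd", "blk.{bid}.attn_q", ...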

llama.cpp

Lines changed: 158 additions & 0 deletions
@@ -214,6 +214,7 @@ enum llm_arch {
     LLM_ARCH_GEMMA,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
+    LLM_ARCH_COMMAND_R,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -243,6 +244,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_GEMMA, "gemma" },
     { LLM_ARCH_STARCODER2, "starcoder2" },
     { LLM_ARCH_MAMBA, "mamba" },
+    { LLM_ARCH_COMMAND_R, "command-r" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };
 
@@ -836,6 +838,21 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
         },
     },
+    {
+        LLM_ARCH_COMMAND_R,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -1610,6 +1627,7 @@ enum e_model {
     MODEL_20B,
     MODEL_30B,
     MODEL_34B,
+    MODEL_35B,
     MODEL_40B,
     MODEL_65B,
     MODEL_70B,
@@ -3237,6 +3255,7 @@ static const char * llama_model_type_name(e_model type) {
         case MODEL_20B: return "20B";
         case MODEL_30B: return "30B";
         case MODEL_34B: return "34B";
+        case MODEL_35B: return "35B";
         case MODEL_40B: return "40B";
         case MODEL_65B: return "65B";
         case MODEL_70B: return "70B";
@@ -3628,6 +3647,13 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_COMMAND_R:
+            {
+                switch (hparams.n_layer) {
+                    case 40: model.type = e_model::MODEL_35B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }
 
@@ -4131,6 +4157,7 @@ static bool llm_load_tensors(
         case LLM_ARCH_LLAMA:
         case LLM_ARCH_REFACT:
         case LLM_ARCH_MINICPM:
+        case LLM_ARCH_COMMAND_R:
             {
                 model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
 
@@ -8302,6 +8329,132 @@ struct llm_build_context {
 
         return gf;
     }
+
+    // FIXME: based on llama right now
+    struct ggml_cgraph * build_command_r() {
+
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
+        cb(inpL, "inp_embd", -1);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
+        cb(inp_pos, "inp_pos", -1);
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        cb(KQ_mask, "KQ_mask", -1);
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            cur = llm_build_ffn(ctx0, cur,
+                    model.layers[il].ffn_up, NULL,
+                    model.layers[il].ffn_gate, NULL,
+                    model.layers[il].ffn_down, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+            cb(cur, "ffn_out", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                cb(cur, "kqv_out", il);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up, NULL,
+                        model.layers[il].ffn_gate, NULL,
+                        model.layers[il].ffn_down, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+
+    }
 };
 
 static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -8473,6 +8626,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_mamba();
             } break;
+        case LLM_ARCH_COMMAND_R:
+            {
+                result = llm.build_command_r();
+            } break;
         default:
             GGML_ASSERT(false);
     }
@@ -13053,6 +13210,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_ORION:
         case LLM_ARCH_INTERNLM2:
         case LLM_ARCH_MINICPM:
+        case LLM_ARCH_COMMAND_R:
             return LLAMA_ROPE_TYPE_NORM;
 
         // the pairs of head values are offset by n_rot/2
