@@ -680,6 +680,7 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output" },
             { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
             { LLM_TENSOR_FFN_DOWN,       "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_GATE,       "blk.%d.ffn_gate" },
             { LLM_TENSOR_FFN_UP,         "blk.%d.ffn_up" },
         },
     },
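For context (not part of the diff): the %d placeholder in these name templates is filled with the layer index when a tensor is looked up, and a ".weight"/".bias" suffix is appended, so the new entry resolves to GGUF tensor names like "blk.0.ffn_gate.weight". A minimal illustration of that expansion, using plain snprintf instead of the tn() helper the loader actually calls:

    #include <cstdio>

    int main() {
        // Illustrative only: expand the "blk.%d.ffn_gate" template for layer 0
        // with a ".weight" suffix, the way the tensor is named in the GGUF file.
        char name[64];
        snprintf(name, sizeof(name), "blk.%d.ffn_gate.%s", /*layer*/ 0, "weight");
        printf("%s\n", name); // prints: blk.0.ffn_gate.weight
        return 0;
    }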
@@ -1921,6 +1922,16 @@ struct llama_layer {
     // mamba bias
     struct ggml_tensor * ssm_conv1d_b;
     struct ggml_tensor * ssm_dt_b;
+
+    // glu mlp (jina-bert)
+    struct ggml_tensor * mlp_gated_layer_w;
+
+    struct ggml_tensor * mlp_wo_w;
+    struct ggml_tensor * mlp_wo_b;
+
+    struct ggml_tensor * mlp_norm_w;
+    struct ggml_tensor * mlp_norm_b;
+
 };

 struct llama_kv_cell {
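A reference-only sketch (not part of the diff) of what the new GLU MLP fields are meant to compute per token: mlp_gated_layer_w fuses the gate and up projections of a GEGLU feed-forward, gelu(gate) * up is then projected back to n_embd by mlp_wo_w/mlp_wo_b, and mlp_norm_w/mlp_norm_b hold the LayerNorm that follows. The gate-first split of the fused weight and the tanh GELU approximation are assumptions here, not something the diff pins down.

    #include <cmath>
    #include <vector>

    // tanh approximation of GELU (assumed variant)
    static float gelu(float x) {
        return 0.5f*x*(1.0f + tanhf(0.79788456f*(x + 0.044715f*x*x*x)));
    }

    // x:        n_embd input activations for one token
    // w_gated:  fused gate+up projection, (2*n_ff) x n_embd, row-major
    // w_o, b_o: down projection back to n_embd, n_embd x n_ff and n_embd
    static std::vector<float> glu_mlp_ref(
            const std::vector<float> & x,
            const std::vector<float> & w_gated,
            const std::vector<float> & w_o,
            const std::vector<float> & b_o,
            int n_embd, int n_ff) {
        std::vector<float> h(n_ff);
        for (int j = 0; j < n_ff; ++j) {
            float gate = 0.0f, up = 0.0f;
            for (int k = 0; k < n_embd; ++k) {
                gate += w_gated[ j        *n_embd + k]*x[k]; // first n_ff rows: gate half (assumed layout)
                up   += w_gated[(n_ff + j)*n_embd + k]*x[k]; // last  n_ff rows: up half   (assumed layout)
            }
            h[j] = gelu(gate)*up; // GEGLU: gelu(gate) elementwise-times up
        }
        std::vector<float> y(n_embd);
        for (int j = 0; j < n_embd; ++j) {
            float acc = b_o[j];
            for (int k = 0; k < n_ff; ++k) {
                acc += w_o[j*n_ff + k]*h[k];
            }
            y[j] = acc; // the LayerNorm held in mlp_norm_w/mlp_norm_b is applied after this
        }
        return y;
    }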
@@ -4813,7 +4824,6 @@ static bool llm_load_tensors(
                 }
             } break;
         case LLM_ARCH_BERT:
-        case LLM_ARCH_JINA_BERT:
         case LLM_ARCH_NOMIC_BERT:
             {
                 model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -4831,7 +4841,7 @@ static bool llm_load_tensors(

                     auto & layer = model.layers[i];

-                    if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT ) {
+                    if (model.arch == LLM_ARCH_BERT) {
                         layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
                         layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd});

@@ -4852,7 +4862,7 @@ static bool llm_load_tensors(
                     layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
                     layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});

-                    if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT ) {
+                    if (model.arch == LLM_ARCH_BERT) {
                         layer.bo       = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
                         layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff});

@@ -4865,6 +4875,44 @@ static bool llm_load_tensors(
                     layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd});
                 }
             } break;
+        case LLM_ARCH_JINA_BERT:
+            {
+                model.tok_embd   = ml.create_tensor(ctx_input,  tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab});      // word_embeddings
+                model.type_embd  = ml.create_tensor(ctx_input,  tn(LLM_TENSOR_TOKEN_TYPES,     "weight"), {n_embd, n_vocab_type}); // token_type_embeddings
+                model.tok_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});               // LayerNorm
+                model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd});               // LayerNorm bias? Not sure needed
+
+                for (int i = 0; i < n_layer; ++i) {
+                    ggml_context * ctx_layer = ctx_for_layer(i);
+                    ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                    auto & layer = model.layers[i]; // JinaBertLayer
+
+                    layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
+                    layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias",   i), {n_embd});
+
+                    layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
+                    layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias",   i), {n_embd_gqa});
+
+                    layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
+                    layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias",   i), {n_embd_gqa});
+
+                    layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); // output_dense
+                    layer.bo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "bias",   i), {n_embd});         // output_dense
+
+                    layer.attn_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); // output_norm
+                    layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias",   i), {n_embd});
+
+                    // TODO: HANDLE ALL THE MLP
+                    layer.mlp_gated_layer_w = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, 2 * n_ff});
+
+                    layer.mlp_wo_w = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                    layer.mlp_wo_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd});
+
+                    layer.mlp_norm_w = ml.create_tensor(ctx_split, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
+                    layer.mlp_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias",   i), {n_embd});
+                }
+            } break;
         case LLM_ARCH_BLOOM:
             {
                 model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
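The TODO above only loads the GLU MLP tensors; the forward pass still has to be wired into the graph builder. A rough sketch of how the gated feed-forward could be expressed with existing ggml ops, assuming it runs inside build_bert()'s layer loop (ctx0, cur, il and n_tokens named as in the other build_* functions) and assuming the gate half occupies the first n_ff rows of the fused {n_embd, 2*n_ff} weight:

    // Hypothetical wiring for the jina-bert GLU MLP -- not part of this diff.
    // cur: [n_embd, n_tokens]
    struct ggml_tensor * gated = ggml_mul_mat(ctx0, model.layers[il].mlp_gated_layer_w, cur); // [2*n_ff, n_tokens]

    // split the fused result into its gate and up halves (gate-first layout assumed)
    struct ggml_tensor * gate = ggml_view_2d(ctx0, gated, n_ff, n_tokens, gated->nb[1], 0);
    struct ggml_tensor * up   = ggml_view_2d(ctx0, gated, n_ff, n_tokens, gated->nb[1], n_ff*ggml_element_size(gated));

    cur = ggml_mul(ctx0, ggml_gelu(ctx0, ggml_cont(ctx0, gate)), ggml_cont(ctx0, up)); // GEGLU
    cur = ggml_mul_mat(ctx0, model.layers[il].mlp_wo_w, cur);                          // down projection
    cur = ggml_add(ctx0, cur, model.layers[il].mlp_wo_b);                              // down projection bias

    // the LayerNorm parameters loaded into mlp_norm_w / mlp_norm_b would be applied here,
    // mirroring how the existing build_bert() path handles layer_out_norm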
@@ -9713,6 +9761,7 @@ static struct ggml_cgraph * llama_build_graph(
                 result = llm.build_refact();
             } break;
         case LLM_ARCH_BERT:
+        case LLM_ARCH_JINA_BERT:
         case LLM_ARCH_NOMIC_BERT:
             {
                 result = llm.build_bert();