
Commit b00d38b

Author: Joan Martinez (committed)
feat: embedding gets results
1 parent a40156a commit b00d38b

4 files changed: 42 additions, 7 deletions


convert-hf-to-gguf.py

Lines changed: 23 additions & 0 deletions
@@ -2170,6 +2170,29 @@ def get_tensors(self):
 class JinaBertModel(BertModel):
     model_arch = gguf.MODEL_ARCH.JINA_BERT

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.intermediate_size = self.hparams["intermediate_size"]
+
+    def get_tensors(self):
+        import string
+        print(f'Intermediate SIZE: {self.intermediate_size}')
+
+        for name, data in super().get_tensors():
+            if 'gated_layers' in name:
+                print(f'name {name} => {data.shape}')
+                d1 = data[:self.intermediate_size, :]
+                name1 = name.replace('gated_layers', 'gated_layers_w')
+                d2 = data[self.intermediate_size:, :]
+                name2 = name.replace('gated_layers', 'gated_layers_v')
+                print(f'd1 {d1.shape}, d2 {d2.shape}')
+                yield name1, d1
+                yield name2, d2
+                continue
+
+            yield name, data
+
+
 @Model.register("GemmaForCausalLM")
 class GemmaModel(Model):
     model_arch = gguf.MODEL_ARCH.GEMMA
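The converter change above splits the fused encoder.layer.{bid}.mlp.gated_layers weight row-wise: the first intermediate_size rows are exported as gated_layers_w and the remaining rows as gated_layers_v. A minimal standalone sketch of that split, assuming the fused checkpoint tensor is a NumPy array of shape (2 * intermediate_size, hidden_size); the sizes below are made up for illustration:

import numpy as np

# Illustrative sizes; a real jina-bert checkpoint takes these from its config.
hidden_size = 8
intermediate_size = 16

# Fused weight as stored in the HF checkpoint: both projections stacked row-wise.
gated_layers = np.arange(2 * intermediate_size * hidden_size, dtype=np.float32)
gated_layers = gated_layers.reshape(2 * intermediate_size, hidden_size)

# Same split as JinaBertModel.get_tensors() above.
d1 = gated_layers[:intermediate_size, :]  # exported as '...gated_layers_w'
d2 = gated_layers[intermediate_size:, :]  # exported as '...gated_layers_v'

assert d1.shape == d2.shape == (intermediate_size, hidden_size)

The two exported names are then routed to separate model tensors by the tensor_mapping.py change below.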

gguf-py/gguf/constants.py

Lines changed: 1 addition & 0 deletions
@@ -369,6 +369,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ATTN_V,
         MODEL_TENSOR.ATTN_OUT,
         MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE,
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.LAYER_OUT_NORM,
     ],

gguf-py/gguf/tensor_mapping.py

Lines changed: 2 additions & 1 deletion
@@ -228,7 +228,7 @@ class TensorNameMap:
             "model.layers.{bid}.feed_forward.w3",     # internlm2
             "encoder.layers.{bid}.mlp.fc11",          # nomic-bert
             "model.layers.{bid}.mlp.c_fc",            # starcoder2
-            "encoder.layer.{bid}.mlp.gated_layers",   # jina-bert
+            "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert
         ),

         MODEL_TENSOR.FFN_UP_EXP: (
@@ -249,6 +249,7 @@ class TensorNameMap:
             "model.layers.layers.{bid}.mlp.gate_proj", # plamo
             "model.layers.{bid}.feed_forward.w1",      # internlm2
             "encoder.layers.{bid}.mlp.fc12",           # nomic-bert
+            "encoder.layer.{bid}.mlp.gated_layers_w",  # jina-bert
         ),

         MODEL_TENSOR.FFN_GATE_EXP: (
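The two hunks above route the split names to different model tensors: gated_layers_v joins the FFN up-projection candidates and gated_layers_w joins the gate-projection candidates. As a rough illustration of how a {bid}-templated mapping resolves per-layer checkpoint names, here is a simplified sketch using a plain dict; the real lookup is done by gguf-py's TensorNameMap, and the GGUF-side names blk.{bid}.ffn_up / blk.{bid}.ffn_gate are assumptions here, not taken from this diff:

# Simplified stand-in for the mapping added above. The GGUF-side names
# ("blk.{bid}.ffn_up", "blk.{bid}.ffn_gate") are assumed for illustration.
HF_TO_GGUF = {
    "encoder.layer.{bid}.mlp.gated_layers_v": "blk.{bid}.ffn_up",
    "encoder.layer.{bid}.mlp.gated_layers_w": "blk.{bid}.ffn_gate",
}

def resolve(hf_name, n_layers):
    """Resolve a concrete per-layer checkpoint name to its GGUF-side name."""
    for bid in range(n_layers):
        for src, dst in HF_TO_GGUF.items():
            if hf_name == src.format(bid=bid):
                return dst.format(bid=bid)
    return None

assert resolve("encoder.layer.3.mlp.gated_layers_w", n_layers=12) == "blk.3.ffn_gate"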

llama.cpp

Lines changed: 16 additions & 6 deletions
@@ -4870,7 +4870,7 @@ static bool llm_load_tensors(
                 model.tok_embd   = ml.create_tensor(ctx_input,  tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab}); // word_embeddings
                 model.type_embd  = ml.create_tensor(ctx_input,  tn(LLM_TENSOR_TOKEN_TYPES,     "weight"), {n_embd, n_vocab_type}); //token_type_embeddings
                 model.tok_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); // LayerNorm
-                model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}); //LayerNorm bias? Not sure needed
+                model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}); //LayerNorm bias

                 for (int i = 0; i < n_layer; ++i) {
                     ggml_context * ctx_layer = ctx_for_layer(i);
@@ -4893,8 +4893,8 @@ static bool llm_load_tensors(
                     layer.attn_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); //output_norm
                     layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias",   i), {n_embd});

-                    // TODO: HANDLE ALL THE MLP
-                    layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, 2 * n_ff});
+                    layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
+                    layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});

                     layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
                     layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd});
@@ -5851,7 +5851,7 @@ static struct ggml_tensor * llm_build_ffn(
         llm_ffn_gate_type   type_gate,
        const llm_build_cb & cb,
                       int   il) {
-    struct ggml_tensor * tmp = ggml_mul_mat(ctx, up, cur);
+    struct ggml_tensor * tmp = up ? ggml_mul_mat(ctx, up, cur): cur;
     cb(tmp, "ffn_up", il);

     if (up_b) {
@@ -7522,8 +7522,11 @@ struct llm_build_context {

         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
+        struct ggml_tensor * inp_pos = nullptr;

-        struct ggml_tensor * inp_pos = build_inp_pos();
+        if (model.arch != LLM_ARCH_JINA_BERT) {
+            inp_pos = build_inp_pos();
+        }
         struct ggml_tensor * inp_mean = build_inp_mean();
         struct ggml_tensor * inp_cls  = build_inp_cls();

@@ -7644,13 +7647,20 @@ struct llm_build_context {
             cb(ffn_inp, "ffn_inp", il);

             // feed-forward network
-            if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT) {
+            if (model.arch == LLM_ARCH_BERT) {
                 cur = llm_build_ffn(ctx0, cur,
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
                         NULL,                      NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+            } else if (model.arch == LLM_ARCH_JINA_BERT) {
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   NULL,
+                        model.layers[il].ffn_gate, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
             } else {
                 cur = llm_build_ffn(ctx0, cur,
                         model.layers[il].ffn_up,   NULL,
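With separate ffn_up and ffn_gate tensors loaded, the new LLM_ARCH_JINA_BERT branch builds a parallel gated FFN (LLM_FFN_GELU with LLM_FFN_PAR): both projections see the same input, one branch is passed through GELU, the two branches are multiplied elementwise, and the result goes through the down projection plus its bias. A small NumPy sketch of that dataflow, assuming the GELU is applied to the gate branch (GEGLU-style); the shapes and the tanh GELU approximation are illustrative, not taken from ggml:

import numpy as np

def gelu(x):
    # tanh approximation of GELU, used here only for illustration
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def jina_bert_ffn(x, w_up, w_gate, w_down, b_down):
    """Parallel gated FFN: down(gelu(gate(x)) * up(x)) + bias."""
    up   = x @ w_up.T          # (n_tokens, n_ff)
    gate = gelu(x @ w_gate.T)  # (n_tokens, n_ff)
    return (gate * up) @ w_down.T + b_down  # (n_tokens, n_embd)

# Illustrative shapes only.
n_embd, n_ff, n_tokens = 8, 16, 4
rng = np.random.default_rng(0)
x      = rng.standard_normal((n_tokens, n_embd))
w_up   = rng.standard_normal((n_ff, n_embd))
w_gate = rng.standard_normal((n_ff, n_embd))
w_down = rng.standard_normal((n_embd, n_ff))
b_down = np.zeros(n_embd)

print(jina_bert_ffn(x, w_up, w_gate, w_down, b_down).shape)  # (4, 8)

The related llm_build_ffn change also makes the up projection optional (cur is passed through unchanged when up is null), although the JINA_BERT branch above still supplies ffn_up.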
