@@ -254,6 +254,7 @@ enum llm_kv {
     LLM_KV_TENSOR_DATA_LAYOUT,
     LLM_KV_EXPERT_COUNT,
     LLM_KV_EXPERT_USED_COUNT,
+    LLM_KV_POOLING_LAYER,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -311,6 +312,7 @@ static std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TENSOR_DATA_LAYOUT,            "%s.tensor_data_layout"  },
     { LLM_KV_EXPERT_COUNT,                  "%s.expert_count"        },
     { LLM_KV_EXPERT_USED_COUNT,             "%s.expert_used_count"   },
+    { LLM_KV_POOLING_LAYER,                 "%s.pooling_layer"       },
 
     { LLM_KV_ATTENTION_HEAD_COUNT,          "%s.attention.head_count"    },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV,       "%s.attention.head_count_kv" },
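The new `%s.pooling_layer` key lands in the model's GGUF metadata under the architecture prefix (e.g. `bert.pooling_layer` for BERT-style embedding models). For illustration only, a minimal sketch of probing that flag with the plain gguf C API; the file name and `bert.` prefix are assumptions, not part of this change:

```cpp
// Sketch: read the pooling flag back out of a converted model's metadata.
#include "ggml.h"   // gguf_* API lived in ggml.h at this point
#include <cstdio>

int main() {
    struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ NULL };
    struct gguf_context * gctx = gguf_init_from_file("bge-small-en.gguf", params);
    if (!gctx) return 1;

    const int kid = gguf_find_key(gctx, "bert.pooling_layer");
    if (kid >= 0) {
        printf("pooling_layer = %s\n", gguf_get_val_bool(gctx, kid) ? "true" : "false");
    }
    gguf_free(gctx);
    return 0;
}
```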
@@ -1524,6 +1526,7 @@ struct llama_hparams {
     float f_max_alibi_bias;
 
     bool causal_attn = true;
+    bool pooling_layer = false;
 
 
     bool operator!=(const llama_hparams & other) const {
@@ -1586,6 +1589,7 @@ struct llama_cparams {
 
     bool mul_mat_q;
     bool offload_kqv;
+    bool do_pooling;
 
     ggml_backend_sched_eval_callback cb_eval;
     void * cb_eval_user_data;
@@ -1881,7 +1885,7 @@ struct llama_context {
     struct ggml_tensor * inp_pos;     // I32 [n_batch]
     struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch]
     struct ggml_tensor * inp_K_shift; // I32 [n_ctx]
-    struct ggml_tensor * inp_sum;     // F32 [1, n_batch]
+    struct ggml_tensor * inp_sum;     // F32 [n_batch, n_batch]
 
 #ifdef GGML_USE_MPI
     ggml_mpi_context * ctx_mpi = NULL;
@@ -3038,6 +3042,7 @@ static void llm_load_hparams(
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
                 ml.get_key(LLM_KV_ATTENTION_CAUSAL,           hparams.causal_attn);
                 ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
+                ml.get_key(LLM_KV_POOLING_LAYER,              hparams.pooling_layer);
 
                 switch (hparams.n_layer) {
                     case 3:
@@ -4845,6 +4850,7 @@ struct llm_build_context {
 
     const bool do_rope_shift;
     const bool causal_attn;
+    const bool do_pooling;
 
     const llm_build_cb & cb;
 
@@ -4889,6 +4895,7 @@ struct llm_build_context {
         n_orig_ctx       (cparams.n_yarn_orig_ctx),
         do_rope_shift    (worst_case || kv_self.has_shift),
         causal_attn      (hparams.causal_attn),
+        do_pooling       (hparams.pooling_layer && cparams.do_pooling),
         cb               (cb),
         buf_compute_meta (lctx.buf_compute_meta) {
             // all initializations should be done in init()
@@ -5737,14 +5744,14 @@ struct llm_build_context {
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
         // get input vectors with right size
+        const size_t stride1 = n_tokens * ggml_type_size(lctx.inp_tokens->type);
         struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
-        struct ggml_tensor * inp_sum = ggml_view_1d(ctx0, lctx.inp_sum, n_tokens, 0);
+        struct ggml_tensor * inp_sum = ggml_view_2d(ctx0, lctx.inp_sum, n_tokens, n_tokens, stride1, 0);
 
         // construct input embeddings (token, type, position)
         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
@@ -5817,8 +5824,10 @@ struct llm_build_context {
         // final output
         cur = inpL;
 
-        // pooling
-        cur = ggml_mul_mat(ctx0, inp_sum, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));
+        // pooling layer
+        if (do_pooling) {
+            cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_sum);
+        }
         cb(cur, "result_embed", -1);
 
         ggml_build_forward_expand(gf, cur);
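A note on what this matmul computes: `inp_sum` is viewed two hunks up as a densely packed `n_tokens x n_tokens` matrix (`stride1` is `n_tokens` 4-byte elements, and I32 and F32 are the same width), and `ggml_mul_mat(ctx, a, b)` treats `a` as transposed. With `a = cont(transpose(cur))` of shape `[n_tokens, n_embd]` and `b = inp_sum`, the result is `[n_embd, n_tokens]`, where column `s` is the sum of the hidden states of the tokens whose mask row `s` is set, i.e. one pooled vector per sequence id. A rough CPU sketch of the same reduction; the names are invented for illustration and this is not llama.cpp code:

```cpp
// hidden: n_tokens x n_embd, row t = last-layer embedding of token t
// seq_of: sequence id of each token
// result: n_seq x n_embd, row s = sum of the embeddings of the tokens in sequence s
#include <vector>

std::vector<std::vector<float>> pool_by_sequence(
        const std::vector<std::vector<float>> & hidden,
        const std::vector<int>                & seq_of,
        int                                     n_seq) {
    const int n_embd = hidden.empty() ? 0 : (int) hidden[0].size();
    std::vector<std::vector<float>> pooled(n_seq, std::vector<float>(n_embd, 0.0f));
    for (size_t t = 0; t < hidden.size(); ++t) {
        for (int k = 0; k < n_embd; ++k) {
            pooled[seq_of[t]][k] += hidden[t][k];   // mask entry is 1.0f => plain sum
        }
    }
    return pooled;
}
```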
@@ -7384,6 +7393,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             data[i] = lctx.kv_self.cells[i].delta;
         }
     }
+
+    if (hparams.pooling_layer && cparams.do_pooling) {
+        const int64_t n_tokens = batch.n_tokens;
+
+        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_sum->buffer));
+        float * data = (float *) lctx.inp_sum->data;
+
+        memset(lctx.inp_sum->data, 0, batch.n_tokens * batch.n_tokens * ggml_element_size(lctx.inp_sum));
+        for (int i = 0; i < n_tokens; ++i) {
+            const llama_seq_id seq_id = batch.seq_id[i][0];
+            data[seq_id*n_tokens + i] = 1.0f;
+        }
+    }
 }
 
 // decode a batch of tokens by evaluating the transformer
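To make the fill loop concrete, here is a standalone illustration (not llama.cpp code) of the mask it builds for a 5-token batch whose tokens carry `seq_id = {0,0,0,1,1}`: row `s` selects the tokens of sequence `s`, and every other entry stays zero after the memset.

```cpp
#include <cstdio>

int main() {
    const int n_tokens = 5;
    const int seq_id[n_tokens] = {0, 0, 0, 1, 1};
    float mask[n_tokens*n_tokens] = {0};                  // plays the role of inp_sum

    for (int i = 0; i < n_tokens; ++i) {
        mask[seq_id[i]*n_tokens + i] = 1.0f;              // same indexing as above
    }
    for (int s = 0; s < n_tokens; ++s) {                  // prints:
        for (int i = 0; i < n_tokens; ++i) {              //   1 1 1 0 0   <- sequence 0
            printf("%d ", (int) mask[s*n_tokens + i]);    //   0 0 0 1 1   <- sequence 1
        }                                                 //   0 0 0 0 0   (unused rows)
        printf("\n");
    }
    return 0;
}
```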
@@ -7616,10 +7638,11 @@ static int llama_decode_internal(
         auto & embedding_out = lctx.embedding;
 
         const int64_t embed_pos  = res ? n_embd * (n_tokens-1) : 0;
+        const int64_t embed_size = res ? n_embd : n_embd * n_tokens;
 
-        embedding_out.resize(n_embd);
+        embedding_out.resize(embed_size);
         ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings);
-        ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embed_pos*sizeof(float), n_embd*sizeof(float));
+        ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embed_pos*sizeof(float), embed_size*sizeof(float));
         ggml_backend_synchronize(embeddings_backend);
     }
 
@@ -10930,6 +10953,7 @@ struct llama_context_params llama_context_default_params() {
         /*.logits_all                  =*/ false,
         /*.embedding                   =*/ false,
         /*.offload_kqv                 =*/ true,
+        /*.do_pooling                  =*/ true,
     };
 
     return result;
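Because `do_pooling` defaults to `true`, callers that want raw per-token embeddings from a model whose GGUF sets `pooling_layer` have to opt out when creating the context. A minimal sketch, assuming `model` was already loaded with `llama_load_model_from_file`:

```cpp
// Sketch: disable the pooling matmul for this context, keeping per-token embeddings.
llama_context_params cparams = llama_context_default_params();
cparams.embedding  = true;    // request embeddings from llama_decode
cparams.do_pooling = false;   // skip per-sequence pooling even if the model has the layer
llama_context * ctx = llama_new_context_with_model(model, cparams);
```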
@@ -11085,6 +11109,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.yarn_beta_slow   = params.yarn_beta_slow;
     cparams.mul_mat_q        = params.mul_mat_q;
     cparams.offload_kqv      = params.offload_kqv;
+    cparams.do_pooling       = params.do_pooling;
 
     cparams.n_ctx            = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
     cparams.rope_freq_base   = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
@@ -11232,7 +11257,7 @@ struct llama_context * llama_new_context_with_model(
         // resized during inference, reserve maximum
         ctx->logits.reserve(hparams.n_vocab*cparams.n_batch);
 
-        if (params.embedding){
+        if (params.embedding) {
             ctx->embedding.resize(hparams.n_embd);
         }
 
@@ -11250,7 +11275,7 @@ struct llama_context * llama_new_context_with_model(
         ctx->inp_pos     = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
         ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch);
         ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx);
-        ctx->inp_sum     = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, 1, cparams.n_batch);
+        ctx->inp_sum     = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
 
         ggml_set_name(ctx->inp_tokens,  "inp_tokens");
         ggml_set_name(ctx->inp_embd,    "inp_embd");
@@ -12108,6 +12133,10 @@ float * llama_get_embeddings(struct llama_context * ctx) {
     return ctx->embedding.data();
 }
 
+float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
+    return ctx->embedding.data() + i*ctx->model.hparams.n_embd;
+}
+
 const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
     return model->vocab.id_to_token[token].text.c_str();
 }
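With pooling enabled, `lctx.embedding` holds `n_embd * n_tokens` floats and the pooled vector for sequence `i` sits at offset `i * n_embd`, which is exactly what the new getter indexes. A hedged usage sketch (batch construction omitted; `model`, `ctx`, and `n_seq` are assumed to be set up by the caller, with sequence ids 0..n_seq-1):

```cpp
// Sketch: fetch one pooled embedding per sequence after llama_decode().
// Assumes the context was created with embedding = true and do_pooling = true.
const int n_embd = llama_n_embd(model);
for (int s = 0; s < n_seq; ++s) {
    const float * emb = llama_get_embeddings_ith(ctx, s);
    float norm = 0.0f;
    for (int k = 0; k < n_embd; ++k) {
        norm += emb[k]*emb[k];
    }
    norm = sqrtf(norm);                       // e.g. normalize before cosine similarity
    printf("seq %d: |emb| = %f\n", s, norm);
}
```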