@@ -12715,7 +12715,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
+    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(lctx.inp_mean);
@@ -12747,7 +12747,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
+    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(lctx.inp_cls);
@@ -12768,7 +12768,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
+    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(lctx.inp_cls);
@@ -12990,14 +12990,15 @@ static int llama_decode_internal(
     std::vector<llama_seq_id *> seq_id_arr;
     std::vector<std::vector<llama_seq_id>> seq_id;
 
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
     // count outputs
-    if (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) {
-        n_outputs = n_tokens_all;
-    } else if (batch_all.logits) {
+    if (batch_all.logits && !embd_pooled) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             n_outputs += batch_all.logits[i] != 0;
         }
-    } else if (lctx.logits_all) {
+    } else if (lctx.logits_all || embd_pooled) {
         n_outputs = n_tokens_all;
     } else {
         // keep last output only
@@ -13043,7 +13044,7 @@ static int llama_decode_internal(
         {
             int32_t n_outputs_new = 0;
 
-            if (u_batch.logits) {
+            if (u_batch.logits && !embd_pooled) {
                 for (uint32_t i = 0; i < n_tokens; i++) {
                     n_outputs_new += u_batch.logits[i] != 0;
                 }
@@ -17202,6 +17203,7 @@ struct llama_context_params llama_context_default_params() {
         /*.n_threads_batch             =*/ GGML_DEFAULT_N_THREADS,
         /*.rope_scaling_type           =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
         /*.pooling_type                =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
+        /*.attention_type              =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
         /*.rope_freq_base              =*/ 0.0f,
         /*.rope_freq_scale             =*/ 0.0f,
         /*.yarn_ext_factor             =*/ -1.0f,
@@ -17448,7 +17450,6 @@ struct llama_context * llama_new_context_with_model(
     }
 
     cparams.yarn_attn_factor *= hparams.rope_attn_factor;
-    cparams.causal_attn = hparams.causal_attn;
 
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
         if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
@@ -17458,6 +17459,12 @@ struct llama_context * llama_new_context_with_model(
         }
     }
 
+    if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
+        cparams.causal_attn = hparams.causal_attn;
+    } else {
+        cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
+    }
+
    if (params.seed == LLAMA_DEFAULT_SEED) {
        params.seed = time(NULL);
    }
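
Note (not part of the patch): a minimal sketch of how a caller might use the new attention_type field together with pooled embeddings. It assumes the enum also defines LLAMA_ATTENTION_TYPE_NON_CAUSAL, which the UNSPECIFIED/CAUSAL values above imply but this diff does not show; make_mean_pool_ctx is a hypothetical helper name.

// Hypothetical caller: request mean-pooled embeddings with non-causal
// attention from a model whose hparams default to causal attention.
#include "llama.h"

static struct llama_context * make_mean_pool_ctx(struct llama_model * model) {
    struct llama_context_params cparams = llama_context_default_params();

    cparams.embeddings     = true;                            // llama_set_inputs now gates pooling inputs on this
    cparams.pooling_type   = LLAMA_POOLING_TYPE_MEAN;         // embd_pooled => all tokens become outputs
    cparams.attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL; // assumption: overrides hparams.causal_attn

    return llama_new_context_with_model(model, cparams);
}

Leaving attention_type at LLAMA_ATTENTION_TYPE_UNSPECIFIED keeps the previous behavior of taking causal_attn from the model hparams, so existing callers are unaffected.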