@@ -13841,7 +13841,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
13841
13841
}
13842
13842
}
13843
13843
13844
- if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
13844
+ if (cparams.embeddings && cparams. pooling_type == LLAMA_POOLING_TYPE_MEAN) {
13845
13845
const int64_t n_tokens = batch.n_tokens;
13846
13846
13847
13847
GGML_ASSERT(lctx.inp_mean);
@@ -13873,7 +13873,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
13873
13873
}
13874
13874
}
13875
13875
13876
- if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
13876
+ if (cparams.embeddings && cparams. pooling_type == LLAMA_POOLING_TYPE_CLS) {
13877
13877
const int64_t n_tokens = batch.n_tokens;
13878
13878
13879
13879
GGML_ASSERT(lctx.inp_cls);
@@ -13894,7 +13894,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
13894
13894
}
13895
13895
}
13896
13896
13897
- if (cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
13897
+ if (cparams.embeddings && cparams. pooling_type == LLAMA_POOLING_TYPE_LAST) {
13898
13898
const int64_t n_tokens = batch.n_tokens;
13899
13899
13900
13900
GGML_ASSERT(lctx.inp_cls);
@@ -14182,14 +14182,15 @@ static int llama_decode_internal(
14182
14182
std::vector<llama_seq_id *> seq_id_arr;
14183
14183
std::vector<std::vector<llama_seq_id>> seq_id;
14184
14184
14185
+ // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
14186
+ const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
14187
+
14185
14188
// count outputs
14186
- if (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) {
14187
- n_outputs = n_tokens_all;
14188
- } else if (batch_all.logits) {
14189
+ if (batch_all.logits && !embd_pooled) {
14189
14190
for (uint32_t i = 0; i < n_tokens_all; ++i) {
14190
14191
n_outputs += batch_all.logits[i] != 0;
14191
14192
}
14192
- } else if (lctx.logits_all) {
14193
+ } else if (lctx.logits_all || embd_pooled ) {
14193
14194
n_outputs = n_tokens_all;
14194
14195
} else {
14195
14196
// keep last output only
@@ -14235,7 +14236,7 @@ static int llama_decode_internal(
14235
14236
{
14236
14237
int32_t n_outputs_new = 0;
14237
14238
14238
- if (u_batch.logits) {
14239
+ if (u_batch.logits && !embd_pooled ) {
14239
14240
for (uint32_t i = 0; i < n_tokens; i++) {
14240
14241
n_outputs_new += u_batch.logits[i] != 0;
14241
14242
}
@@ -18534,6 +18535,7 @@ struct llama_context_params llama_context_default_params() {
18534
18535
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
18535
18536
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
18536
18537
/*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
18538
+ /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
18537
18539
/*.rope_freq_base =*/ 0.0f,
18538
18540
/*.rope_freq_scale =*/ 0.0f,
18539
18541
/*.yarn_ext_factor =*/ -1.0f,
@@ -18786,7 +18788,6 @@ struct llama_context * llama_new_context_with_model(
18786
18788
}
18787
18789
18788
18790
cparams.yarn_attn_factor *= hparams.rope_attn_factor;
18789
- cparams.causal_attn = hparams.causal_attn;
18790
18791
18791
18792
if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
18792
18793
if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
@@ -18796,6 +18797,12 @@ struct llama_context * llama_new_context_with_model(
18796
18797
}
18797
18798
}
18798
18799
18800
+ if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
18801
+ cparams.causal_attn = hparams.causal_attn;
18802
+ } else {
18803
+ cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
18804
+ }
18805
+
18799
18806
if (params.seed == LLAMA_DEFAULT_SEED) {
18800
18807
params.seed = time(NULL);
18801
18808
}
0 commit comments