@@ -13840,7 +13840,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
13840
13840
}
13841
13841
}
13842
13842
13843
- if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
13843
+ if (cparams.embeddings && cparams. pooling_type == LLAMA_POOLING_TYPE_MEAN) {
13844
13844
const int64_t n_tokens = batch.n_tokens;
13845
13845
13846
13846
GGML_ASSERT(lctx.inp_mean);
@@ -13872,7 +13872,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
13872
13872
}
13873
13873
}
13874
13874
13875
- if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
13875
+ if (cparams.embeddings && cparams. pooling_type == LLAMA_POOLING_TYPE_CLS) {
13876
13876
const int64_t n_tokens = batch.n_tokens;
13877
13877
13878
13878
GGML_ASSERT(lctx.inp_cls);
@@ -13893,7 +13893,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
13893
13893
}
13894
13894
}
13895
13895
13896
- if (cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
13896
+ if (cparams.embeddings && cparams. pooling_type == LLAMA_POOLING_TYPE_LAST) {
13897
13897
const int64_t n_tokens = batch.n_tokens;
13898
13898
13899
13899
GGML_ASSERT(lctx.inp_cls);
@@ -14181,14 +14181,15 @@ static int llama_decode_internal(
14181
14181
std::vector<llama_seq_id *> seq_id_arr;
14182
14182
std::vector<std::vector<llama_seq_id>> seq_id;
14183
14183
14184
+ // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
14185
+ const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
14186
+
14184
14187
// count outputs
14185
- if (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) {
14186
- n_outputs = n_tokens_all;
14187
- } else if (batch_all.logits) {
14188
+ if (batch_all.logits && !embd_pooled) {
14188
14189
for (uint32_t i = 0; i < n_tokens_all; ++i) {
14189
14190
n_outputs += batch_all.logits[i] != 0;
14190
14191
}
14191
- } else if (lctx.logits_all) {
14192
+ } else if (lctx.logits_all || embd_pooled ) {
14192
14193
n_outputs = n_tokens_all;
14193
14194
} else {
14194
14195
// keep last output only
@@ -14234,7 +14235,7 @@ static int llama_decode_internal(
14234
14235
{
14235
14236
int32_t n_outputs_new = 0;
14236
14237
14237
- if (u_batch.logits) {
14238
+ if (u_batch.logits && !embd_pooled ) {
14238
14239
for (uint32_t i = 0; i < n_tokens; i++) {
14239
14240
n_outputs_new += u_batch.logits[i] != 0;
14240
14241
}
@@ -18533,6 +18534,7 @@ struct llama_context_params llama_context_default_params() {
18533
18534
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
18534
18535
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
18535
18536
/*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
18537
+ /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
18536
18538
/*.rope_freq_base =*/ 0.0f,
18537
18539
/*.rope_freq_scale =*/ 0.0f,
18538
18540
/*.yarn_ext_factor =*/ -1.0f,
@@ -18785,7 +18787,6 @@ struct llama_context * llama_new_context_with_model(
18785
18787
}
18786
18788
18787
18789
cparams.yarn_attn_factor *= hparams.rope_attn_factor;
18788
- cparams.causal_attn = hparams.causal_attn;
18789
18790
18790
18791
if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
18791
18792
if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
@@ -18795,6 +18796,12 @@ struct llama_context * llama_new_context_with_model(
18795
18796
}
18796
18797
}
18797
18798
18799
+ if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
18800
+ cparams.causal_attn = hparams.causal_attn;
18801
+ } else {
18802
+ cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
18803
+ }
18804
+
18798
18805
if (params.seed == LLAMA_DEFAULT_SEED) {
18799
18806
params.seed = time(NULL);
18800
18807
}
0 commit comments