@@ -758,6 +758,7 @@ int llama_context::encode(llama_batch & inp_batch) {
         t_compute_start_us = ggml_time_us();
     }

+    // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();

     n_queued_tokens += n_tokens;
@@ -940,6 +941,25 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
     }

+    // this indicates we are doing pooled embedding
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    int64_t n_outputs_all = 0;
+
+    // count outputs
+    for (uint32_t i = 0; i < n_tokens_all; ++i) {
+        n_outputs_all += batch.logits[i] != 0;
+    }
+
+    if (embd_pooled) {
+        // require that all tokens are output
+        if (n_outputs_all != n_tokens_all) {
+            LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %" PRId64 ", n_tokens_all = %" PRId64 ")\n",
+                    __func__, n_outputs_all, n_tokens_all);
+            return -1;
+        }
+    }
+
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);

     GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
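For readers skimming the hunk above, the added logic in isolation: the output count is now taken up front from batch.logits, and pooled embeddings are rejected unless every token is marked as an output. This is a minimal standalone sketch; the free function and its parameter list are illustrative only and not part of llama.cpp.

#include <cstdint>

// Mirrors the added check: `logits` stands in for batch.logits and
// `embd_pooled` for the pooled-embedding flag derived from cparams.
static int64_t count_outputs_or_fail(const int8_t * logits, uint32_t n_tokens_all, bool embd_pooled) {
    int64_t n_outputs_all = 0;

    for (uint32_t i = 0; i < n_tokens_all; ++i) {
        n_outputs_all += logits[i] != 0;
    }

    // pooled embeddings require that every token is an output
    if (embd_pooled && n_outputs_all != n_tokens_all) {
        return -1;
    }

    return n_outputs_all;
}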
@@ -949,25 +969,9 @@ int llama_context::decode(llama_batch & inp_batch) {
     }
     n_queued_tokens += n_tokens_all;

-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
+    // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();

-    int64_t n_outputs_all = 0;
-
-    // count outputs
-    if (batch.logits && !embd_pooled) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs_all += batch.logits[i] != 0;
-        }
-    } else if (embd_pooled) {
-        n_outputs_all = n_tokens_all;
-    } else {
-        // keep last output only
-        n_outputs_all = 1;
-    }
-
     bool did_optimize = false;

     // handle any pending defrags/shifts
@@ -1029,7 +1033,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     do {
         const auto & ubatch = mstate->get_ubatch();

-        // count the outputs in this u_batch
+        // count the outputs in this ubatch
         {
             int32_t n_outputs_new = 0;

@@ -2073,7 +2077,7 @@ void llama_context::opt_epoch_iter(

         n_queued_tokens += n_tokens_all;

-        // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+        // this indicates we are doing pooled embedding
         const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

         embd_seq.clear();
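Caller-facing consequence of the change, as a hedged sketch: before this commit, decode() ignored batch.logits when pooling was enabled and output all tokens; now the caller must mark every token as an output, or decode() returns -1. The helper below is hypothetical (not an API in llama.h), but the llama_batch fields it touches are the public ones.

#include <vector>
#include "llama.h"

// Hypothetical helper: append one sequence to a batch created with
// llama_batch_init(), marking every token as an output so that pooled
// embeddings pass the new check in llama_context::decode().
static void append_seq_all_outputs(llama_batch & batch, const std::vector<llama_token> & tokens, llama_seq_id seq_id) {
    for (size_t i = 0; i < tokens.size(); ++i) {
        const int32_t idx = batch.n_tokens++;

        batch.token   [idx]    = tokens[i];
        batch.pos     [idx]    = (llama_pos) i;
        batch.n_seq_id[idx]    = 1;
        batch.seq_id  [idx][0] = seq_id;
        batch.logits  [idx]    = 1; // required: pooled embeddings need all tokens output
    }
}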