
Commit f6e1a7a

context : simplify output counting logic during decode (#14142)
* batch : remove logits_all flag

ggml-ci

* context : simplify output counting logic during decode

ggml-ci

* cont : fix comments
1 parent c3ee46f commit f6e1a7a
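
The thrust of the change: llama_batch_allocr now always populates batch.logits (defaulting to last-token-only when the caller leaves it unset), so llama_context::decode can count requested outputs with a single pass over those flags instead of branching on logits_all or the pooling mode. A minimal sketch of that counting, mirroring the new decode code (the standalone function name here is illustrative):

#include <cstdint>

// Sketch: with the per-token output flags guaranteed to be present,
// the number of requested outputs is simply the count of non-zero flags.
static int64_t count_outputs(const int8_t * output_flags, uint32_t n_tokens) {
    int64_t n_outputs = 0;
    for (uint32_t i = 0; i < n_tokens; ++i) {
        n_outputs += output_flags[i] != 0;
    }
    return n_outputs;
}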

File tree

3 files changed, +28 -23 lines changed

src/llama-batch.cpp

Lines changed: 4 additions & 3 deletions
@@ -306,9 +306,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
         batch.seq_id = seq_id.data();
     }
     if (!batch.logits) {
-        logits.resize(batch.n_tokens);
-        logits[logits.size() - 1] = true;
-        batch.logits = logits.data();
+        // by default return the output only for the last token
+        output.resize(batch.n_tokens);
+        output[output.size() - 1] = true;
+        batch.logits = output.data();
     }
 }
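
At the API level, this default matters when the caller leaves llama_batch::logits unset, as llama_batch_get_one() does: only the last token's output is returned. A minimal sketch of a decode call relying on that default (assumes the public llama.h API; error handling is trimmed):

#include "llama.h"

// Sketch: llama_batch_get_one() leaves batch.logits null, so the allocator's
// `output` vector marks only the last token, and llama_get_logits_ith(ctx, -1)
// is the only valid logits row after decoding.
static const float * decode_last_only(llama_context * ctx, llama_token * tokens, int32_t n_tokens) {
    llama_batch batch = llama_batch_get_one(tokens, n_tokens);
    if (llama_decode(ctx, batch) != 0) {
        return nullptr;
    }
    return llama_get_logits_ith(ctx, -1); // logits for the last token only
}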

src/llama-batch.h

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ struct llama_batch_allocr {
     std::vector<llama_pos>       pos;
     std::vector<int32_t>         n_seq_id;
     std::vector<llama_seq_id *>  seq_id;
-    std::vector<int8_t>          logits;
+    std::vector<int8_t>          output;
 
     // optionally fulfill the batch returned by llama_batch_get_one
     llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);

src/llama-context.cpp

Lines changed: 23 additions & 19 deletions
@@ -758,6 +758,7 @@ int llama_context::encode(llama_batch & inp_batch) {
         t_compute_start_us = ggml_time_us();
     }
 
+    // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
 
     n_queued_tokens += n_tokens;
@@ -940,6 +941,25 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
     }
 
+    // this indicates we are doing pooled embedding
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    int64_t n_outputs_all = 0;
+
+    // count outputs
+    for (uint32_t i = 0; i < n_tokens_all; ++i) {
+        n_outputs_all += batch.logits[i] != 0;
+    }
+
+    if (embd_pooled) {
+        // require that all tokens are output
+        if (n_outputs_all != n_tokens_all) {
+            LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %" PRId64 ", n_tokens_all = %" PRId64 ")\n",
+                    __func__, n_outputs_all, n_tokens_all);
+            return -1;
+        }
+    }
+
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);
 
     GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
@@ -949,25 +969,9 @@ int llama_context::decode(llama_batch & inp_batch) {
     }
     n_queued_tokens += n_tokens_all;
 
-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
+    // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
 
-    int64_t n_outputs_all = 0;
-
-    // count outputs
-    if (batch.logits && !embd_pooled) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs_all += batch.logits[i] != 0;
-        }
-    } else if (embd_pooled) {
-        n_outputs_all = n_tokens_all;
-    } else {
-        // keep last output only
-        n_outputs_all = 1;
-    }
-
     bool did_optimize = false;
 
     // handle any pending defrags/shifts
@@ -1029,7 +1033,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     do {
         const auto & ubatch = mstate->get_ubatch();
 
-        // count the outputs in this u_batch
+        // count the outputs in this ubatch
         {
             int32_t n_outputs_new = 0;
 
@@ -2073,7 +2077,7 @@ void llama_context::opt_epoch_iter(
 
     n_queued_tokens += n_tokens_all;
 
-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    // this indicates we are doing pooled embedding
     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
     embd_seq.clear();
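
For callers, the practical consequence of the new check in llama_context::decode is that pooled embeddings require every token in the batch to be flagged as an output; decode now fails up front instead of silently overriding the flags. The helper below is a sketch against the public llama_batch fields (the function name is made up; the batch is assumed to come from llama_batch_init(n_tokens, 0, 1)):

#include "llama.h"

// Sketch: mark every token as an output so that n_outputs_all == n_tokens_all,
// which is what the pooled-embedding path now requires.
static void set_all_outputs(llama_batch & batch, const llama_token * tokens, int32_t n_tokens, llama_seq_id seq) {
    for (int32_t i = 0; i < n_tokens; ++i) {
        batch.token[i]     = tokens[i];
        batch.pos[i]       = i;          // consecutive positions within the sequence
        batch.n_seq_id[i]  = 1;
        batch.seq_id[i][0] = seq;
        batch.logits[i]    = true;       // request output for every token
    }
    batch.n_tokens = n_tokens;
}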
