context : simplify output counting logic during decode #14142

Merged
merged 3 commits on Jun 12, 2025

17 changes: 6 additions & 11 deletions src/llama-batch.cpp
@@ -105,12 +105,7 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
ubatch.seq_id = batch->seq_id + seq.offset;
}
}
if (logits_all) {
for (size_t i = 0; i < length; ++i) {
ubatch.output[ubatch.n_tokens + i] = 1;
out_ids.push_back(ids[seq.offset + i]);
}
} else if (batch->logits) {
if (batch->logits) {
if (ubatch.equal_seqs) {
for (size_t i = 0; i < length; ++i) {
size_t id = ids[seq.offset + i];
@@ -197,11 +192,10 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
return ubatch;
}

llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
GGML_ASSERT(batch.n_tokens >= 0);
this->batch = &batch;
this->n_embd = n_embd;
this->logits_all = logits_all;

n_tokens = batch.n_tokens;
ids.resize(n_tokens);
@@ -312,9 +306,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
batch.seq_id = seq_id.data();
}
if (!batch.logits) {
logits.resize(batch.n_tokens);
logits[logits.size() - 1] = true;
batch.logits = logits.data();
// by default return the output only for the last token
output.resize(batch.n_tokens);
output[output.size() - 1] = true;
batch.logits = output.data();
}
}

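Note on the llama_batch_allocr hunk above: it renames the backing vector from logits to output without changing behavior, so when the caller provides no per-token flags, only the last token is marked for output. A minimal standalone sketch of that default rule (plain C++ with a hypothetical helper name, not the llama.cpp API):

    #include <cstdint>
    #include <vector>

    // Build default per-token output flags for n_tokens tokens: nothing is
    // flagged except the last token, mirroring the "by default return the
    // output only for the last token" comment in the hunk above.
    std::vector<int8_t> default_output_flags(size_t n_tokens) {
        std::vector<int8_t> output(n_tokens, 0);
        if (!output.empty()) {
            output.back() = 1;
        }
        return output;
    }
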
6 changes: 2 additions & 4 deletions src/llama-batch.h
@@ -39,8 +39,6 @@ struct llama_sbatch {

size_t n_embd;

bool logits_all; // TODO: remove once lctx.logits_all is removed too

// sorted indices into the batch
std::vector<int64_t> ids;
// batch indices of the output
@@ -76,7 +74,7 @@ struct llama_sbatch {
llama_ubatch split_seq(size_t n_ubatch);

llama_sbatch() = default;
llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
};

// temporary allocate memory for the input batch if needed
@@ -87,7 +85,7 @@ struct llama_batch_allocr {
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id;
std::vector<int8_t> logits;
std::vector<int8_t> output;

// optionally fulfill the batch returned by llama_batch_get_one
llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
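With the logits_all member and constructor parameter removed from llama_sbatch, a caller can no longer force all-token output through the sbatch itself; the per-token flags carried by the batch express that instead. A small caller-side sketch under that assumption (plain C++ helper for illustration, not part of llama.cpp):

    #include <cstdint>
    #include <vector>

    // Flag every token for output, e.g. for a caller that previously relied
    // on logits_all = true (illustrative helper, hypothetical name).
    std::vector<int8_t> all_output_flags(size_t n_tokens) {
        return std::vector<int8_t>(n_tokens, 1);
    }
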
48 changes: 26 additions & 22 deletions src/llama-context.cpp
@@ -758,13 +758,14 @@ int llama_context::encode(llama_batch & inp_batch) {
t_compute_start_us = ggml_time_us();
}

// TODO: this clear of the buffer can easily be forgotten - need something better
embd_seq.clear();

n_queued_tokens += n_tokens;

const int64_t n_embd = hparams.n_embd;

llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);

const llama_ubatch ubatch = sbatch.split_simple(n_tokens);

@@ -940,6 +941,25 @@ int llama_context::decode(llama_batch & inp_batch) {
}
}

// this indicates we are doing pooled embedding
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

int64_t n_outputs_all = 0;

// count outputs
for (uint32_t i = 0; i < n_tokens_all; ++i) {
n_outputs_all += batch.logits[i] != 0;
}

if (embd_pooled) {
// require that all tokens are output
if (n_outputs_all != n_tokens_all) {
LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %" PRId64 ", n_tokens_all = %" PRId64 ")\n",
__func__, n_outputs_all, n_tokens_all);
return -1;
}
}

GGML_ASSERT(n_tokens_all <= cparams.n_batch);

GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
@@ -949,25 +969,9 @@ int llama_context::decode(llama_batch & inp_batch) {
}
n_queued_tokens += n_tokens_all;

// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

// TODO: this clear of the buffer can easily be forgotten - need something better
embd_seq.clear();

int64_t n_outputs_all = 0;

// count outputs
if (batch.logits && !embd_pooled) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
n_outputs_all += batch.logits[i] != 0;
}
} else if (embd_pooled) {
n_outputs_all = n_tokens_all;
} else {
// keep last output only
n_outputs_all = 1;
}

bool did_optimize = false;

// handle any pending defrags/shifts
@@ -976,7 +980,7 @@
llama_memory_state_ptr mstate;

while (true) {
mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
if (!mstate) {
return -2;
}
@@ -1029,7 +1033,7 @@ int llama_context::decode(llama_batch & inp_batch) {
do {
const auto & ubatch = mstate->get_ubatch();

// count the outputs in this u_batch
// count the outputs in this ubatch
{
int32_t n_outputs_new = 0;

@@ -2073,14 +2077,14 @@ void llama_context::opt_epoch_iter(

n_queued_tokens += n_tokens_all;

// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
// this indicates we are doing pooled embedding
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

embd_seq.clear();

int64_t n_outputs_all = n_tokens_all;

auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
break;
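After this change, decode derives n_outputs_all from batch.logits alone, and pooled embeddings only add the invariant that every token must be flagged. A self-contained sketch of that counting rule, repackaged as a free function for illustration (same logic as the new counting block in the decode hunk above, but not the verbatim llama.cpp code):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Count how many tokens request output; when pooled embeddings are
    // enabled, every token must be flagged, otherwise the batch is rejected.
    int64_t count_outputs(const std::vector<int8_t> & logits, bool embd_pooled) {
        int64_t n_outputs_all = 0;
        for (int8_t flag : logits) {
            n_outputs_all += flag != 0;
        }
        if (embd_pooled && n_outputs_all != (int64_t) logits.size()) {
            std::fprintf(stderr, "pooled embedding requires that all tokens are output\n");
            return -1;
        }
        return n_outputs_all;
    }
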
4 changes: 2 additions & 2 deletions src/llama-kv-cache-recurrent.cpp
@@ -359,10 +359,10 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
return result;
}

llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
GGML_UNUSED(embd_pooled);

auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
auto sbatch = llama_sbatch(batch, hparams.n_embd, false);

std::vector<llama_ubatch> ubatches;

3 changes: 1 addition & 2 deletions src/llama-kv-cache-recurrent.h
@@ -32,8 +32,7 @@ class llama_kv_cache_recurrent : public llama_memory_i {
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) override;
bool embd_pooled) override;

llama_memory_state_ptr init_full() override;

6 changes: 3 additions & 3 deletions src/llama-kv-cache-unified-iswa.cpp
@@ -95,12 +95,12 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
return kv_swa->seq_pos_max(seq_id);
}

llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
GGML_UNUSED(embd_pooled);

// first try simple split
do {
auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
auto sbatch = llama_sbatch(batch, hparams.n_embd, true);

std::vector<llama_ubatch> ubatches;

@@ -128,7 +128,7 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch

// if it fails, try equal split
do {
auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
auto sbatch = llama_sbatch(batch, hparams.n_embd, false);

std::vector<llama_ubatch> ubatches;

3 changes: 1 addition & 2 deletions src/llama-kv-cache-unified-iswa.h
@@ -34,8 +34,7 @@ class llama_kv_cache_unified_iswa : public llama_memory_i {
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) override;
bool embd_pooled) override;

llama_memory_state_ptr init_full() override;

5 changes: 2 additions & 3 deletions src/llama-kv-cache-unified.cpp
@@ -310,12 +310,11 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
llama_memory_state_ptr llama_kv_cache_unified::init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) {
bool embd_pooled) {
GGML_UNUSED(embd_pooled);

do {
auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
auto sbatch = llama_sbatch(batch, hparams.n_embd, true);

std::vector<llama_ubatch> ubatches;
while (sbatch.n_tokens > 0) {
3 changes: 1 addition & 2 deletions src/llama-kv-cache-unified.h
@@ -59,8 +59,7 @@ class llama_kv_cache_unified : public llama_memory_i {
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) override;
bool embd_pooled) override;

llama_memory_state_ptr init_full() override;

3 changes: 1 addition & 2 deletions src/llama-memory.h
@@ -73,8 +73,7 @@ struct llama_memory_i {
virtual llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) = 0;
bool embd_pooled) = 0;

// simulate full cache, used for allocating worst-case compute buffers
virtual llama_memory_state_ptr init_full() = 0;
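Every memory implementation in this PR drops the trailing logits_all parameter from init_batch. A self-contained sketch of the narrowed interface shape (stub types defined locally so the snippet compiles on its own; only the parameter list mirrors the diff, the bodies are placeholders):

    #include <cstdint>
    #include <memory>

    // Local stand-ins for the real llama.cpp types, just to keep the sketch self-contained.
    struct batch_stub        { /* tokens, positions, per-token output flags, ... */ };
    struct memory_state_stub { };
    using  memory_state_ptr = std::unique_ptr<memory_state_stub>;

    struct memory_iface {
        virtual ~memory_iface() = default;

        // after this PR: no logits_all parameter; output selection lives in the batch
        virtual memory_state_ptr init_batch(
                const batch_stub & batch,
                          uint32_t n_ubatch,
                              bool embd_pooled) = 0;
    };

    struct memory_impl : memory_iface {
        memory_state_ptr init_batch(
                const batch_stub & batch,
                          uint32_t n_ubatch,
                              bool embd_pooled) override {
            (void) batch; (void) n_ubatch; (void) embd_pooled;
            return std::make_unique<memory_state_stub>();
        }
    };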