
Commit c3ee46f

batch : remove logits_all flag (#14141)
ggml-ci
1 parent e2c0b6e commit c3ee46f

10 files changed: +17 -30 lines
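The change drops the internal `logits_all` flag: which tokens produce logits is now determined solely by the per-token `logits` array on `llama_batch`. A minimal caller-side sketch, assuming the public `llama.h` API (token values and counts here are placeholders):

    // request an output only for the last token; there is no
    // separate logits_all switch to flip any more
    llama_batch batch = llama_batch_init(/*n_tokens_alloc*/ 4, /*embd*/ 0, /*n_seq_max*/ 1);
    for (int32_t i = 0; i < 4; ++i) {
        batch.token[i]     = tokens[i]; // placeholder token ids
        batch.pos[i]       = i;
        batch.n_seq_id[i]  = 1;
        batch.seq_id[i][0] = 0;
        batch.logits[i]    = i == 3;    // 1 = produce logits for this token
    }
    batch.n_tokens = 4;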

src/llama-batch.cpp

Lines changed: 2 additions & 8 deletions

@@ -105,12 +105,7 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
             ubatch.seq_id = batch->seq_id + seq.offset;
         }
     }
-    if (logits_all) {
-        for (size_t i = 0; i < length; ++i) {
-            ubatch.output[ubatch.n_tokens + i] = 1;
-            out_ids.push_back(ids[seq.offset + i]);
-        }
-    } else if (batch->logits) {
+    if (batch->logits) {
         if (ubatch.equal_seqs) {
             for (size_t i = 0; i < length; ++i) {
                 size_t id = ids[seq.offset + i];
@@ -197,11 +192,10 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
     return ubatch;
 }
 
-llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
+llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
     GGML_ASSERT(batch.n_tokens >= 0);
     this->batch = &batch;
     this->n_embd = n_embd;
-    this->logits_all = logits_all;
 
     n_tokens = batch.n_tokens;
     ids.resize(n_tokens);

src/llama-batch.h

Lines changed: 1 addition & 3 deletions

@@ -39,8 +39,6 @@ struct llama_sbatch {
 
     size_t n_embd;
 
-    bool logits_all; // TODO: remove once lctx.logits_all is removed too
-
     // sorted indices into the batch
     std::vector<int64_t> ids;
     // batch indices of the output
@@ -76,7 +74,7 @@ struct llama_sbatch {
     llama_ubatch split_seq(size_t n_ubatch);
 
     llama_sbatch() = default;
-    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
+    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
 };
 
 // temporary allocate memory for the input batch if needed
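For reference, the trimmed constructor is used by the memory modules in the files below. A hedged sketch of the call pattern this diff settles on, mirroring the kv-cache call sites (`n_ubatch` is whatever micro-batch size the caller configured):

    // split a batch into micro-batches; output selection now comes
    // entirely from batch.logits inside the sbatch
    llama_sbatch sbatch(batch, hparams.n_embd, /* simple_split */ true);

    std::vector<llama_ubatch> ubatches;
    while (sbatch.n_tokens > 0) {
        ubatches.push_back(sbatch.split_simple(n_ubatch));
    }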

src/llama-context.cpp

Lines changed: 3 additions & 3 deletions

@@ -764,7 +764,7 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     const int64_t n_embd = hparams.n_embd;
 
-    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
 
     const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
 
@@ -976,7 +976,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     llama_memory_state_ptr mstate;
 
     while (true) {
-        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
         if (!mstate) {
             return -2;
         }
@@ -2080,7 +2080,7 @@ void llama_context::opt_epoch_iter(
 
     int64_t n_outputs_all = n_tokens_all;
 
-    auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
+    auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
     if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
         LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
         break;
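Note that `encode()` previously forced an output for every token via `/* logits_all */ true`; after this change that intent has to travel with the batch itself. An illustrative encoder-side sketch under that assumption (not code from this commit):

    // assumption for illustration: mark every token for output
    // before encoding, since the logits_all override is gone
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        batch.logits[i] = 1;
    }
    llama_encode(ctx, batch);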

src/llama-kv-cache-recurrent.cpp

Lines changed: 2 additions & 2 deletions

@@ -359,10 +359,10 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }
 
-llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
     GGML_UNUSED(embd_pooled);
 
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
 
     std::vector<llama_ubatch> ubatches;

src/llama-kv-cache-recurrent.h

Lines changed: 1 addition & 2 deletions

@@ -32,8 +32,7 @@ class llama_kv_cache_recurrent : public llama_memory_i {
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
+            bool embd_pooled) override;
 
     llama_memory_state_ptr init_full() override;

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 3 additions & 3 deletions

@@ -95,12 +95,12 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
     GGML_UNUSED(embd_pooled);
 
     // first try simple split
     do {
-        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
 
         std::vector<llama_ubatch> ubatches;
 
@@ -128,7 +128,7 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
 
     // if it fails, try equal split
     do {
-        auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
 
         std::vector<llama_ubatch> ubatches;

src/llama-kv-cache-unified-iswa.h

Lines changed: 1 addition & 2 deletions

@@ -34,8 +34,7 @@ class llama_kv_cache_unified_iswa : public llama_memory_i {
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
+            bool embd_pooled) override;
 
     llama_memory_state_ptr init_full() override;

src/llama-kv-cache-unified.cpp

Lines changed: 2 additions & 3 deletions

@@ -310,12 +310,11 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 llama_memory_state_ptr llama_kv_cache_unified::init_batch(
         const llama_batch & batch,
         uint32_t n_ubatch,
-        bool embd_pooled,
-        bool logits_all) {
+        bool embd_pooled) {
     GGML_UNUSED(embd_pooled);
 
     do {
-        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
 
         std::vector<llama_ubatch> ubatches;
         while (sbatch.n_tokens > 0) {

src/llama-kv-cache-unified.h

Lines changed: 1 addition & 2 deletions

@@ -59,8 +59,7 @@ class llama_kv_cache_unified : public llama_memory_i {
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
+            bool embd_pooled) override;
 
     llama_memory_state_ptr init_full() override;

src/llama-memory.h

Lines changed: 1 addition & 2 deletions

@@ -73,8 +73,7 @@ struct llama_memory_i {
     virtual llama_memory_state_ptr init_batch(
            const llama_batch & batch,
            uint32_t n_ubatch,
-           bool embd_pooled,
-           bool logits_all) = 0;
+           bool embd_pooled) = 0;
 
     // simulate full cache, used for allocating worst-case compute buffers
     virtual llama_memory_state_ptr init_full() = 0;
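Any out-of-tree `llama_memory_i` implementation must adopt the three-parameter pure virtual. A hypothetical subclass sketch (class name and body are illustrative only, not part of this commit):

    // hypothetical implementation of the updated interface
    class my_memory : public llama_memory_i {
        llama_memory_state_ptr init_batch(
                const llama_batch & batch,
                uint32_t n_ubatch,
                bool embd_pooled) override {
            GGML_UNUSED(embd_pooled);
            auto sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
            // ... split into ubatches and build a memory state, as the
            //     kv-cache implementations above do
            return nullptr; // placeholder
        }
        size_t n_embd = 0; // illustrative member
        // ... remaining llama_memory_i overrides elided
    };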
