@@ -17134,10 +17134,10 @@ static void llama_graph_compute(
17134
17134
//
17135
17135
static int llama_decode_internal(
17136
17136
llama_context & lctx,
17137
- llama_batch batch_all ) { // TODO: rename back to batch
17137
+ llama_batch batch ) {
17138
17138
17139
17139
lctx.is_encoding = false;
17140
- const uint32_t n_tokens_all = batch_all .n_tokens;
17140
+ const uint32_t n_tokens_all = batch .n_tokens;
17141
17141
17142
17142
if (n_tokens_all == 0) {
17143
17143
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@@ -17148,12 +17148,12 @@ static int llama_decode_internal(
17148
17148
const auto & hparams = model.hparams;
17149
17149
const auto & cparams = lctx.cparams;
17150
17150
17151
- GGML_ASSERT((!batch_all .token && batch_all .embd) || (batch_all .token && !batch_all .embd)); // NOLINT
17151
+ GGML_ASSERT((!batch .token && batch .embd) || (batch .token && !batch .embd)); // NOLINT
17152
17152
17153
- if (batch_all .token) {
17153
+ if (batch .token) {
17154
17154
for (uint32_t i = 0; i < n_tokens_all; ++i) {
17155
- if (batch_all .token[i] < 0 || (uint32_t)batch_all .token[i] >= model.vocab.n_vocab) {
17156
- LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch_all .token[i]);
17155
+ if (batch .token[i] < 0 || (uint32_t)batch .token[i] >= model.vocab.n_vocab) {
17156
+ LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch .token[i]);
17157
17157
return -1;
17158
17158
}
17159
17159
}
@@ -17184,9 +17184,9 @@ static int llama_decode_internal(
17184
17184
lctx.embd_seq.clear();
17185
17185
17186
17186
// count outputs
17187
- if (batch_all .logits && !embd_pooled) {
17187
+ if (batch .logits && !embd_pooled) {
17188
17188
for (uint32_t i = 0; i < n_tokens_all; ++i) {
17189
- n_outputs += batch_all .logits[i] != 0;
17189
+ n_outputs += batch .logits[i] != 0;
17190
17190
}
17191
17191
} else if (lctx.logits_all || embd_pooled) {
17192
17192
n_outputs = n_tokens_all;
@@ -17195,7 +17195,7 @@ static int llama_decode_internal(
17195
17195
n_outputs = 1;
17196
17196
}
17197
17197
17198
- lctx.sbatch.from_batch(batch_all , n_embd,
17198
+ lctx.sbatch.from_batch(batch , n_embd,
17199
17199
/* simple_split */ !kv_self.recurrent,
17200
17200
/* logits_all */ n_outputs == n_tokens_all);
17201
17201
0 commit comments