Skip to content

Commit f164ba9

Browse files
committed
batch : rework llama_batch_allocr
ggml-ci
1 parent a681b4b commit f164ba9

File tree

5 files changed

+94
-55
lines changed

5 files changed

+94
-55
lines changed

src/llama-batch.cpp

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,9 +279,15 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
279279
);
280280
}
281281

282-
llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
282+
// Default-construct an empty allocator; call init() before using the batch.
llama_batch_allocr::llama_batch_allocr() = default;
283+
284+
bool llama_batch_allocr::init(struct llama_batch in_batch, llama_pos p0) {
285+
GGML_ASSERT(in_batch.n_tokens > 0);
286+
287+
clear();
288+
283289
batch = in_batch;
284-
GGML_ASSERT(batch.n_tokens > 0);
290+
285291
if (!batch.pos) {
286292
assert(p0 >= 0);
287293
pos.resize(batch.n_tokens);
@@ -290,13 +296,15 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
290296
}
291297
batch.pos = pos.data();
292298
}
299+
293300
if (!batch.n_seq_id) {
294301
n_seq_id.resize(batch.n_tokens);
295302
for (int32_t i = 0; i < batch.n_tokens; i++) {
296303
n_seq_id[i] = seq_id_0.size();
297304
}
298305
batch.n_seq_id = n_seq_id.data();
299306
}
307+
300308
if (!batch.seq_id) {
301309
seq_id.resize(batch.n_tokens + 1);
302310
seq_id[batch.n_tokens] = NULL;
@@ -305,12 +313,27 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
305313
}
306314
batch.seq_id = seq_id.data();
307315
}
316+
308317
if (!batch.logits) {
309318
// by default return the output only for the last token
310319
output.resize(batch.n_tokens);
311320
output[output.size() - 1] = true;
312321
batch.logits = output.data();
313322
}
323+
324+
return true;
325+
}
326+
327+
// Read-only access to the batch prepared by the most recent init() call.
const llama_batch & llama_batch_allocr::get_batch() const {
    return batch;
}
330+
331+
// Reset all internal state (the wrapped batch and every backing vector)
// so the allocator can be reused by a subsequent init() call.
void llama_batch_allocr::clear() {
    batch = {};

    pos.clear();
    n_seq_id.clear();
    seq_id.clear();
    output.clear();
}
315338

316339
//

src/llama-batch.h

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ struct llama_ubatch {
1818
llama_token * token; // [n_tokens]
1919
float * embd; // [n_embd, n_tokens]
2020
llama_pos * pos; // [n_tokens]
21-
int32_t * n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
22-
llama_seq_id ** seq_id; // [n_seqs] // TODO: become llama_seq_id * seq_id;
21+
int32_t * n_seq_id; // [n_seqs]
22+
llama_seq_id ** seq_id; // [n_seqs]
2323
int8_t * output; // [n_tokens]
2424
};
2525

@@ -78,15 +78,23 @@ struct llama_sbatch {
7878
};
7979

8080
// temporary allocate memory for the input batch if needed
81-
struct llama_batch_allocr {
82-
struct llama_batch batch;
81+
// Temporarily allocates memory for the input batch if needed: owns default
// storage for optional llama_batch fields (pos, n_seq_id, seq_id, logits)
// when the caller did not provide them.
class llama_batch_allocr {
public:
    llama_batch_allocr();

    // optionally fulfill the batch returned by llama_batch_get_one;
    // clears previous state, then backfills any arrays missing from in_batch
    // (p0 is used to derive positions when in_batch.pos is null — requires p0 >= 0 in that case)
    bool init(llama_batch in_batch, llama_pos p0);

    // read-only view of the batch prepared by the last init() call
    const llama_batch & get_batch() const;

private:
    // reset all internal state so the object can be reused
    void clear();

    // the wrapped batch; missing fields point into the vectors below
    llama_batch batch;

    std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id

    // backing storage used when the input batch omits the corresponding field
    std::vector<llama_pos>      pos;
    std::vector<int32_t>        n_seq_id;
    std::vector<llama_seq_id *> seq_id;
    std::vector<int8_t>         output;
};

src/llama-context.cpp

Lines changed: 45 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "llama-context.h"
22

33
#include "llama-impl.h"
4+
#include "llama-batch.h"
45
#include "llama-io.h"
56
#include "llama-memory.h"
67
#include "llama-mmap.h"
@@ -18,7 +19,8 @@
1819
llama_context::llama_context(
1920
const llama_model & model,
2021
llama_context_params params) :
21-
model(model) {
22+
model(model),
23+
batch_allocr(std::make_unique<llama_batch_allocr>()) {
2224
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
2325

2426
t_start_us = model.t_start_us;
@@ -494,7 +496,7 @@ float * llama_context::get_logits() {
494496
}
495497

496498
float * llama_context::get_logits_ith(int32_t i) {
497-
int32_t j = -1;
499+
int64_t j = -1;
498500

499501
try {
500502
if (logits == nullptr) {
@@ -517,7 +519,7 @@ float * llama_context::get_logits_ith(int32_t i) {
517519
}
518520
if (j >= n_outputs) {
519521
// This should not happen
520-
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
522+
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
521523
}
522524

523525
return logits + j*model.vocab.n_tokens();
@@ -536,7 +538,7 @@ float * llama_context::get_embeddings() {
536538
}
537539

538540
float * llama_context::get_embeddings_ith(int32_t i) {
539-
int32_t j = -1;
541+
int64_t j = -1;
540542

541543
try {
542544
if (embd == nullptr) {
@@ -559,7 +561,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
559561
}
560562
if (j >= n_outputs) {
561563
// This should not happen
562-
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
564+
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
563565
}
564566

565567
return embd + j*model.hparams.n_embd;
@@ -727,18 +729,19 @@ int llama_context::encode(llama_batch & inp_batch) {
727729

728730
// temporary allocate memory for the input batch if needed
729731
// note: during encode, we always pass the full sequence starting from pos = 0
730-
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);
732+
batch_allocr->init(inp_batch, inp_batch.pos ? -1 : 0);
731733

732-
const llama_batch & batch = batch_allocr.batch;
733-
const int32_t n_tokens = batch.n_tokens;
734+
const llama_batch & batch = batch_allocr->get_batch();
735+
736+
const uint32_t n_tokens = batch.n_tokens;
734737

735738
const auto & hparams = model.hparams;
736739

737740
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
738741

739742
// TODO: move the validation to the llama_batch_allocr
740743
if (batch.token) {
741-
for (int32_t i = 0; i < n_tokens; ++i) {
744+
for (uint32_t i = 0; i < n_tokens; ++i) {
742745
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
743746
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
744747
return -1;
@@ -775,7 +778,7 @@ int llama_context::encode(llama_batch & inp_batch) {
775778
return -2;
776779
};
777780

778-
for (int32_t i = 0; i < n_tokens; ++i) {
781+
for (uint32_t i = 0; i < n_tokens; ++i) {
779782
output_ids[i] = i;
780783
}
781784

@@ -831,7 +834,8 @@ int llama_context::encode(llama_batch & inp_batch) {
831834

832835
GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
833836

834-
for (int32_t i = 0; i < n_tokens; i++) {
837+
// TODO: fix sequence indexing
838+
for (uint32_t i = 0; i < n_tokens; i++) {
835839
const llama_seq_id seq_id = ubatch.seq_id[i][0];
836840
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
837841
continue;
@@ -881,7 +885,7 @@ int llama_context::encode(llama_batch & inp_batch) {
881885
// TODO: the sequence indexing here is likely not correct in the general case
882886
// probably works only for split_simple
883887
cross.seq_ids_enc.resize(n_tokens);
884-
for (int32_t i = 0; i < n_tokens; i++) {
888+
for (uint32_t i = 0; i < n_tokens; i++) {
885889
cross.seq_ids_enc[i].clear();
886890
for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
887891
llama_seq_id seq_id = ubatch.seq_id[i][s];
@@ -912,30 +916,30 @@ int llama_context::decode(llama_batch & inp_batch) {
912916
}
913917

914918
// temporary allocate memory for the input batch if needed
915-
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1);
919+
batch_allocr->init(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1);
916920

917-
const llama_batch & batch = batch_allocr.batch;
921+
const llama_batch & batch = batch_allocr->get_batch();
918922

919923
const auto & vocab = model.vocab;
920924
const auto & hparams = model.hparams;
921925

922926
const int32_t n_vocab = vocab.n_tokens();
927+
const int64_t n_embd = hparams.n_embd;
923928

924-
const int64_t n_tokens_all = batch.n_tokens;
925-
const int64_t n_embd = hparams.n_embd;
929+
const uint32_t n_tokens_all = batch.n_tokens;
926930

927931
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
928932

929933
// TODO: move the validation to the llama_batch_allocr
930934
if (batch.token) {
931-
for (int64_t i = 0; i < n_tokens_all; ++i) {
935+
for (uint32_t i = 0; i < n_tokens_all; ++i) {
932936
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
933-
LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
937+
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
934938
return -1;
935939
}
936940

937941
if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
938-
LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
942+
LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
939943
return -1;
940944
}
941945
}
@@ -944,7 +948,7 @@ int llama_context::decode(llama_batch & inp_batch) {
944948
// this indicates we are doing pooled embedding
945949
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
946950

947-
int64_t n_outputs_all = 0;
951+
uint32_t n_outputs_all = 0;
948952

949953
// count outputs
950954
for (uint32_t i = 0; i < n_tokens_all; ++i) {
@@ -954,7 +958,7 @@ int llama_context::decode(llama_batch & inp_batch) {
954958
if (embd_pooled) {
955959
// require that all tokens are output
956960
if (n_outputs_all != n_tokens_all) {
957-
LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %" PRId64 ", n_tokens_all = %" PRId64 ")\n",
961+
LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
958962
__func__, n_outputs_all, n_tokens_all);
959963
return -1;
960964
}
@@ -1024,7 +1028,7 @@ int llama_context::decode(llama_batch & inp_batch) {
10241028

10251029
// reserve output buffer
10261030
if (output_reserve(n_outputs_all) < n_outputs_all) {
1027-
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
1031+
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
10281032
return -2;
10291033
};
10301034

@@ -1063,6 +1067,7 @@ int llama_context::decode(llama_batch & inp_batch) {
10631067
pos_min[s] = std::numeric_limits<llama_pos>::max();
10641068
}
10651069

1070+
// TODO: fix sequence indexing
10661071
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
10671072
const auto & seq_id = ubatch.seq_id[i][0];
10681073

@@ -1176,14 +1181,14 @@ int llama_context::decode(llama_batch & inp_batch) {
11761181
n_outputs = n_outputs_all;
11771182

11781183
// set output mappings
1179-
{
1184+
if (n_outputs > 0) {
11801185
bool sorted_output = true;
11811186

11821187
auto & out_ids = mstate->out_ids();
11831188

1184-
GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
1189+
GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
11851190

1186-
for (int64_t i = 0; i < n_outputs_all; ++i) {
1191+
for (int64_t i = 0; i < n_outputs; ++i) {
11871192
int64_t out_id = out_ids[i];
11881193
output_ids[out_id] = i;
11891194
if (out_id != i) {
@@ -1195,20 +1200,22 @@ int llama_context::decode(llama_batch & inp_batch) {
11951200
// note: this is mostly relevant for recurrent models atm
11961201
if (!sorted_output) {
11971202
const uint32_t n_vocab = model.vocab.n_tokens();
1198-
const uint32_t n_embd = model.hparams.n_embd;
1203+
const uint64_t n_embd = model.hparams.n_embd;
11991204

12001205
GGML_ASSERT((size_t) n_outputs == out_ids.size());
12011206

12021207
// TODO: is there something more efficient which also minimizes swaps?
12031208
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
1204-
for (int32_t i = 0; i < n_outputs - 1; ++i) {
1205-
int32_t j_min = i;
1206-
for (int32_t j = i + 1; j < n_outputs; ++j) {
1209+
for (uint32_t i = 0; i < n_outputs - 1; ++i) {
1210+
uint32_t j_min = i;
1211+
for (uint32_t j = i + 1; j < n_outputs; ++j) {
12071212
if (out_ids[j] < out_ids[j_min]) {
12081213
j_min = j;
12091214
}
12101215
}
1211-
if (j_min == i) { continue; }
1216+
if (j_min == i) {
1217+
continue;
1218+
}
12121219
std::swap(out_ids[i], out_ids[j_min]);
12131220
if (logits_size > 0) {
12141221
for (uint32_t k = 0; k < n_vocab; k++) {
@@ -1221,8 +1228,10 @@ int llama_context::decode(llama_batch & inp_batch) {
12211228
}
12221229
}
12231230
}
1231+
12241232
std::fill(output_ids.begin(), output_ids.end(), -1);
1225-
for (int32_t i = 0; i < n_outputs; ++i) {
1233+
1234+
for (uint32_t i = 0; i < n_outputs; ++i) {
12261235
output_ids[out_ids[i]] = i;
12271236
}
12281237
}
@@ -1242,7 +1251,7 @@ int llama_context::decode(llama_batch & inp_batch) {
12421251
// output
12431252
//
12441253

1245-
int32_t llama_context::output_reserve(int32_t n_outputs) {
1254+
uint32_t llama_context::output_reserve(int32_t n_outputs) {
12461255
const auto & hparams = model.hparams;
12471256
const auto & vocab = model.vocab;
12481257

@@ -1308,8 +1317,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
13081317
// set all ids as invalid (negative)
13091318
std::fill(output_ids.begin(), output_ids.end(), -1);
13101319

1311-
this->n_outputs = 0;
1312-
this->n_outputs_max = n_outputs_max;
1320+
this->n_outputs = 0;
13131321

13141322
return n_outputs_max;
13151323
}
@@ -1800,14 +1808,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
18001808

18011809
std::vector<int32_t> w_output_pos;
18021810

1803-
GGML_ASSERT(n_outputs <= n_outputs_max);
1804-
18051811
w_output_pos.resize(n_outputs);
18061812

18071813
// build a more compact representation of the output ids
18081814
for (size_t i = 0; i < n_batch(); ++i) {
18091815
// map an output id to a position in the batch
1810-
int32_t pos = output_ids[i];
1816+
int64_t pos = output_ids[i];
18111817
if (pos >= 0) {
18121818
GGML_ASSERT(pos < n_outputs);
18131819
w_output_pos[pos] = i;
@@ -2082,7 +2088,7 @@ void llama_context::opt_epoch_iter(
20822088

20832089
embd_seq.clear();
20842090

2085-
int64_t n_outputs_all = n_tokens_all;
2091+
uint32_t n_outputs_all = n_tokens_all;
20862092

20872093
auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
20882094
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
@@ -2092,7 +2098,7 @@ void llama_context::opt_epoch_iter(
20922098

20932099
// reserve output buffer
20942100
if (output_reserve(n_outputs_all) < n_outputs_all) {
2095-
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
2101+
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
20962102
GGML_ABORT("TODO: handle this error");
20972103
};
20982104

0 commit comments

Comments
 (0)