
Commit bb1c81c

context : simplify sbatch logic
ggml-ci
1 parent a540bcd commit bb1c81c

6 files changed: +65 / -65 lines
src/llama-batch.cpp

Lines changed: 5 additions & 1 deletion
@@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
     return ubatch;
 }
 
-void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
+llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
     GGML_ASSERT(batch.n_tokens >= 0);
     this->batch = &batch;
     this->n_embd = n_embd;
@@ -203,6 +203,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
     for (size_t i = 0; i < n_tokens; ++i) {
         ids[i] = i;
     }
+
     if (simple_split) {
         seq.resize(1);
         llama_sbatch_seq & s = seq[0];
@@ -212,6 +213,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
         s.length = n_tokens;
         return;
     }
+
     std::sort(ids.begin(), ids.end(),
         [&batch](size_t a, size_t b) {
             int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
@@ -239,6 +241,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
             return n_seq_a > n_seq_b;
         }
     );
+
     // init seq
     llama_sbatch_seq * last_seq = nullptr;
 
@@ -262,6 +265,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
         seq.push_back(new_seq);
         last_seq = &seq.back();
     }
+
     // keep shared prompts first at the end, then sort by length descending.
     std::sort(seq.begin(), seq.end(),
         [](llama_sbatch_seq & a, llama_sbatch_seq & b) {

src/llama-batch.h

Lines changed: 2 additions & 1 deletion
@@ -70,7 +70,8 @@ struct llama_sbatch {
     // sequence-wise split
     llama_ubatch split_seq(size_t n_ubatch);
 
-    void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
+    llama_sbatch() = default;
+    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
 };
 
 // temporary allocate memory for the input batch if needed

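Note: the header change above replaces the two-step "default-construct, then call from_batch()" flow with a real converting constructor, while llama_sbatch() = default keeps the type default-constructible. A minimal standalone sketch of that pattern, using hypothetical stand-in types rather than the real llama.cpp structs:

#include <cstddef>

// stand-in for llama_batch (hypothetical, for illustration only)
struct batch_t {
    int n_tokens = 0;
};

// stand-in for llama_sbatch: the init logic lives in the constructor now
struct sbatch_t {
    const batch_t * src          = nullptr;
    size_t          n_embd       = 0;
    bool            simple_split = false;

    sbatch_t() = default; // still default-constructible, like llama_sbatch() = default

    sbatch_t(const batch_t & b, size_t n_embd, bool simple_split = false)
        : src(&b), n_embd(n_embd), simple_split(simple_split) {}
};

int main() {
    batch_t b;
    b.n_tokens = 4;

    // old style (roughly): sbatch_t s; s.from_batch(b, 4096, true);
    // new style: construct the sbatch right where it is needed
    sbatch_t s(b, 4096, /* simple_split */ true);

    return (s.src == &b && s.simple_split) ? 0 : 1;
}

In the commit itself, encode() and decode() now build this object as a local instead of reusing a persistent member, which is why the sbatch field disappears from llama-context.h further down.
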
src/llama-context.cpp

Lines changed: 40 additions & 58 deletions
@@ -799,9 +799,6 @@ enum llama_pooling_type llama_context::pooling_type() const {
 }
 
 float * llama_context::get_logits() {
-    // reorder logits for backward compatibility
-    output_reorder();
-
     return logits;
 }
 
@@ -844,9 +841,6 @@ float * llama_context::get_logits_ith(int32_t i) {
 }
 
 float * llama_context::get_embeddings() {
-    // reorder embeddings for backward compatibility
-    output_reorder();
-
     return embd;
 }
 
@@ -1028,7 +1022,7 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     const int64_t n_embd = hparams.n_embd;
 
-    sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
 
     const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
 
@@ -1219,13 +1213,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_all = 1;
     }
 
-    const bool logits_all = n_outputs_all == n_tokens_all;
-
-    const bool is_recurrent = llama_model_is_recurrent(&model);
-
-    sbatch.from_batch(batch, n_embd,
-            /* simple_split */ !is_recurrent,
-            /* logits_all   */ logits_all);
+    llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
 
     // reserve output buffer
     if (output_reserve(n_outputs_all) < n_outputs_all) {
@@ -1382,18 +1370,52 @@ int llama_context::decode(llama_batch & inp_batch) {
     {
         bool sorted_output = true;
 
-        GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);
+        auto & out_ids = sbatch.out_ids;
+
+        GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
 
         for (int64_t i = 0; i < n_outputs_all; ++i) {
-            int64_t out_id = sbatch.out_ids[i];
+            int64_t out_id = out_ids[i];
             output_ids[out_id] = i;
             if (out_id != i) {
                 sorted_output = false;
             }
         }
 
-        if (sorted_output) {
-            sbatch.out_ids.clear();
+        // make the outputs have the same order they had in the user-provided batch
+        // note: this is mostly relevant for recurrent models atm
+        if (!sorted_output) {
+            const uint32_t n_vocab = model.vocab.n_tokens();
+            const uint32_t n_embd  = model.hparams.n_embd;
+
+            GGML_ASSERT((size_t) n_outputs == out_ids.size());
+
+            // TODO: is there something more efficient which also minimizes swaps?
+            // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
+            for (int32_t i = 0; i < n_outputs - 1; ++i) {
+                int32_t j_min = i;
+                for (int32_t j = i + 1; j < n_outputs; ++j) {
+                    if (out_ids[j] < out_ids[j_min]) {
+                        j_min = j;
+                    }
+                }
+                if (j_min == i) { continue; }
+                std::swap(out_ids[i], out_ids[j_min]);
+                if (logits_size > 0) {
+                    for (uint32_t k = 0; k < n_vocab; k++) {
+                        std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
+                    }
+                }
+                if (embd_size > 0) {
+                    for (uint32_t k = 0; k < n_embd; k++) {
+                        std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
+                    }
+                }
+            }
+            std::fill(output_ids.begin(), output_ids.end(), -1);
+            for (int32_t i = 0; i < n_outputs; ++i) {
+                output_ids[out_ids[i]] = i;
+            }
         }
     }
 
@@ -1504,44 +1526,6 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
     return n_outputs_max;
 }
 
-void llama_context::output_reorder() {
-    auto & out_ids = sbatch.out_ids;
-    if (!out_ids.empty()) {
-        const uint32_t n_vocab = model.vocab.n_tokens();
-        const uint32_t n_embd  = model.hparams.n_embd;
-
-        GGML_ASSERT((size_t) n_outputs == out_ids.size());
-
-        // TODO: is there something more efficient which also minimizes swaps?
-        // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
-        for (int32_t i = 0; i < n_outputs - 1; ++i) {
-            int32_t j_min = i;
-            for (int32_t j = i + 1; j < n_outputs; ++j) {
-                if (out_ids[j] < out_ids[j_min]) {
-                    j_min = j;
-                }
-            }
-            if (j_min == i) { continue; }
-            std::swap(out_ids[i], out_ids[j_min]);
-            if (logits_size > 0) {
-                for (uint32_t k = 0; k < n_vocab; k++) {
-                    std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
-                }
-            }
-            if (embd_size > 0) {
-                for (uint32_t k = 0; k < n_embd; k++) {
-                    std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
-                }
-            }
-        }
-        std::fill(output_ids.begin(), output_ids.end(), -1);
-        for (int32_t i = 0; i < n_outputs; ++i) {
-            output_ids[out_ids[i]] = i;
-        }
-        out_ids.clear();
-    }
-}
-
 //
 // graph
 //
@@ -1982,8 +1966,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
     {
         LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
 
-        output_reorder();
-
         const auto n_outputs = this->n_outputs;
         const auto & output_ids = this->output_ids;
 

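The reordering that used to live in output_reorder() now runs once, inline in decode(), whenever the outputs were produced in a different order than the user-provided batch. For reference, here is a small self-contained toy of that step: the same selection-sort idea as the inlined code, run on made-up data, mirroring every swap of out_ids into the flat logits rows (an illustration only, not the llama.cpp code itself):

#include <cstdio>
#include <utility>
#include <vector>

int main() {
    const int n_vocab = 2;                       // toy vocab size (row width)
    std::vector<int>   out_ids = {2, 0, 1};      // batch position of each produced row
    std::vector<float> logits  = {               // one row of n_vocab values per output
        0.2f, 0.2f,   // row produced for batch position 2
        0.0f, 0.0f,   // row produced for batch position 0
        0.1f, 0.1f,   // row produced for batch position 1
    };

    const int n_outputs = (int) out_ids.size();

    // selection sort on out_ids (chosen to minimize swaps), mirroring each swap into the logits rows
    for (int i = 0; i < n_outputs - 1; ++i) {
        int j_min = i;
        for (int j = i + 1; j < n_outputs; ++j) {
            if (out_ids[j] < out_ids[j_min]) {
                j_min = j;
            }
        }
        if (j_min == i) {
            continue;
        }
        std::swap(out_ids[i], out_ids[j_min]);
        for (int k = 0; k < n_vocab; ++k) {
            std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
        }
    }

    // rows now follow the original batch order: 0.0, 0.1, 0.2
    for (int i = 0; i < n_outputs; ++i) {
        std::printf("output %d (batch pos %d): %.1f\n", i, out_ids[i], logits[i*n_vocab]);
    }
}

With the work done inline, get_logits(), get_embeddings(), and state_write_data() no longer need to call output_reorder() first, which is why those calls are removed above.
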
src/llama-context.h

Lines changed: 0 additions & 5 deletions
@@ -137,10 +137,6 @@ struct llama_context {
     // Returns max number of outputs for which space was reserved.
     int32_t output_reserve(int32_t n_outputs);
 
-    // make the outputs have the same order they had in the user-provided batch
-    // TODO: maybe remove this
-    void output_reorder();
-
     //
     // graph
     //
@@ -196,7 +192,6 @@
     llama_cparams cparams;
     llama_adapter_cvec cvec;
     llama_adapter_loras loras;
-    llama_sbatch sbatch;
 
     llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
 
src/llama-kv-cache.cpp

Lines changed: 12 additions & 0 deletions
@@ -476,6 +476,12 @@ bool llama_kv_cache_unified::find_slot(
     return true;
 }
 
+llama_sbatch llama_kv_cache_unified::sbatch_init(
+        const llama_batch & batch,
+        bool logits_all) {
+    return llama_sbatch(batch, hparams.n_embd, true, logits_all);
+}
+
 llama_ubatch llama_kv_cache_unified::ubatch_next(
         llama_sbatch & sbatch,
         uint32_t n_ubatch,
@@ -1547,6 +1553,12 @@ bool llama_kv_cache_recurrent::find_slot(
     return n >= n_seqs;
 }
 
+llama_sbatch llama_kv_cache_recurrent::sbatch_init(
+        const llama_batch & batch,
+        bool logits_all) {
+    return llama_sbatch(batch, hparams.n_embd, false, logits_all);
+}
+
 llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
     if (embd_pooled) {
         // Pooled embeddings cannot be split across ubatches (yet)

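These two overrides are what let decode() drop the llama_model_is_recurrent() check: each KV cache picks its own split strategy (simple split for the unified cache, sequence-wise split for the recurrent one) and hands back a ready-made sbatch. A rough standalone sketch of that shape, with hypothetical stand-in types rather than the real classes:

#include <cstdio>

struct batch  {};
struct sbatch { bool simple_split; };

struct kv_cache {
    virtual ~kv_cache() = default;
    // each cache implementation returns an sbatch configured for its own split strategy
    virtual sbatch sbatch_init(const batch & b, bool logits_all) = 0;
};

struct kv_cache_unified : kv_cache {
    sbatch sbatch_init(const batch &, bool) override {
        return sbatch{ /* simple_split */ true };   // unified cache: simple split
    }
};

struct kv_cache_recurrent : kv_cache {
    sbatch sbatch_init(const batch &, bool) override {
        return sbatch{ /* simple_split */ false };  // recurrent cache: sequence-wise split
    }
};

int main() {
    batch b;
    kv_cache_unified   u;
    kv_cache_recurrent r;

    // the caller no longer decides the split strategy itself:
    std::printf("unified:   simple_split=%d\n", (int) u.sbatch_init(b, true).simple_split);
    std::printf("recurrent: simple_split=%d\n", (int) r.sbatch_init(b, true).simple_split);
}
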
src/llama-kv-cache.h

Lines changed: 6 additions & 0 deletions
@@ -45,6 +45,8 @@ struct llama_kv_cache : public llama_memory_i {
 
     virtual bool find_slot(const llama_ubatch & batch) = 0;
 
+    virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
+
     // different KV caches require different batch splitting strategies
     virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
 
@@ -143,6 +145,8 @@ class llama_kv_cache_unified : public llama_kv_cache {
     // to the first cell of the slot.
     bool find_slot(const llama_ubatch & batch) override;
 
+    llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
+
     llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
 
     static uint32_t get_padding(const llama_cparams & cparams);
@@ -269,6 +273,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
     // to the first cell of the slot.
     bool find_slot(const llama_ubatch & batch) override;
 
+    llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
+
     llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
 
     // find how many cells are currently in use
