Skip to content

Commit 7035c79

Browse files
committed
llama : batch
ggml-ci
1 parent a7df071 commit 7035c79

File tree

5 files changed

+341
-318
lines changed

5 files changed

+341
-318
lines changed

src/llama-batch.cpp

Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,306 @@
11
#include "llama-batch.h"
2+
3+
#include <cstring>
4+
#include <algorithm>
5+
6+
// Prepares the staging buffers for a new ubatch of at most `n_ubatch` tokens.
//
// Finished sequences are dropped from the tail of `seq`, every ubatch_* buffer
// is resized, and an empty ubatch (zero tokens, zero sequences) whose pointers
// refer to those buffers is returned.
//
//   n_ubatch - number of token slots to reserve
//   has_embd - when true the embedding buffer is reserved and `token` is null;
//              otherwise the token buffer is reserved and `embd` is null
llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
    // Sequences at the tail with no tokens left are done. The previous ubatch
    // is assumed to be gone, so nothing should still refer to their values.
    while (!seq.empty() && seq.back().length == 0) {
        seq.pop_back();
    }

    // Only one of the token / embedding buffers is needed; the other is emptied.
    ubatch_token.resize(has_embd ? 0 : n_ubatch);
    ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
    ubatch_pos.resize(n_ubatch);
    ubatch_n_seq_id.resize(n_ubatch);
    ubatch_seq_id.resize(n_ubatch);
    ubatch_output.resize(n_ubatch);

    return llama_ubatch {
        /*equal_seqs   =*/ true,
        /*n_tokens     =*/ 0,
        /*n_seq_tokens =*/ 0,
        /*n_seqs       =*/ 0,
        /*token        =*/ has_embd ? nullptr : ubatch_token.data(),
        /*embd         =*/ has_embd ? ubatch_embd.data() : nullptr,
        /*pos          =*/ ubatch_pos.data(),
        /*n_seq_id     =*/ ubatch_n_seq_id.data(),
        /*seq_id       =*/ ubatch_seq_id.data(),
        /*output       =*/ ubatch_output.data(),
    };
}
37+
38+
// Moves the next `length` tokens of `seq` into `ubatch`.
//
// When ubatch.equal_seqs is set, values are gathered through `ids` into the
// ubatch buffers. Otherwise ("simple split") the ubatch pointers are redirected
// into the source batch arrays at seq.offset, so no per-token copy is made for
// fields the batch provides.
//
// On return the ubatch counters, seq.offset / seq.length and the remaining
// token count of this sbatch have been advanced by `length`, and the batch
// indices of output tokens have been appended to `out_ids`.
void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
    GGML_ASSERT(batch != nullptr);
    GGML_ASSERT(length <= seq.length);
    // Can only add sequences of equal lengths to a batch,
    // otherwise it isn't clear to which sequence a token belongs
    GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
    GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);

    // None of these change until the bookkeeping at the end of the function.
    const bool   gather = ubatch.equal_seqs;
    const size_t dst0   = ubatch.n_tokens; // first free slot in the ubatch buffers
    const size_t src0   = seq.offset;      // first unconsumed entry of this sequence

    // NOTE: one loop per field, kept separate for cache-friendliness

    // tokens
    if (batch->token == nullptr) {
        ubatch.token = nullptr;
    } else if (gather) {
        for (size_t k = 0; k < length; ++k) {
            ubatch.token[dst0 + k] = batch->token[ids[src0 + k]];
        }
    } else {
        // simple split
        ubatch.token = batch->token + src0;
    }

    // embeddings (n_embd floats per token)
    if (batch->embd == nullptr) {
        ubatch.embd = nullptr;
    } else if (gather) {
        for (size_t k = 0; k < length; ++k) {
            memcpy(
                ubatch.embd + (n_embd * (dst0 + k)),
                batch->embd + (n_embd * ids[src0 + k]),
                n_embd * sizeof(float)
            );
        }
    } else {
        // simple split
        ubatch.embd = batch->embd + (n_embd * src0);
    }

    // positions
    if (gather) {
        for (size_t k = 0; k < length; ++k) {
            ubatch.pos[dst0 + k] = batch->pos[ids[src0 + k]];
        }
    } else {
        // simple split
        ubatch.pos = batch->pos + src0;
    }

    // sequence ids
    if (gather) {
        // one entry per sequence
        ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
        if (seq.seq_id) {
            ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
        }
    } else {
        // simple split: one entry per token
        if (batch->n_seq_id) {
            ubatch.n_seq_id = batch->n_seq_id + src0;
        } else {
            for (size_t k = 0; k < length; ++k) {
                ubatch.n_seq_id[ubatch.n_seqs + k] = 1;
            }
        }
        if (batch->seq_id) {
            ubatch.seq_id = batch->seq_id + src0;
        }
    }

    // output flags
    if (logits_all) {
        for (size_t k = 0; k < length; ++k) {
            ubatch.output[dst0 + k] = 1;
            out_ids.push_back(ids[src0 + k]);
        }
    } else if (batch->logits) {
        if (gather) {
            for (size_t k = 0; k < length; ++k) {
                const size_t id        = ids[src0 + k];
                const int8_t is_output = batch->logits[id];
                ubatch.output[dst0 + k] = is_output;
                if (is_output) {
                    out_ids.push_back(id);
                }
            }
        } else {
            // simple split: read the flags through the redirected pointer
            ubatch.output = batch->logits + src0;
            for (size_t k = 0; k < length; ++k) {
                if (ubatch.output[k] != 0) {
                    out_ids.push_back(src0 + k);
                }
            }
        }
    } else {
        // only the last token of the batch is an output
        for (size_t k = 0; k < length; ++k) {
            const size_t id      = ids[src0 + k];
            const int8_t is_last = id == ids.size() - 1;
            ubatch.output[dst0 + k] = is_last;
            if (is_last) {
                out_ids.push_back(id);
            }
        }
    }

    // The first sequence added fixes the per-sequence token count.
    if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
        ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
    }
    ubatch.n_tokens += length;
    ubatch.n_seqs   += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
    seq.offset      += length;
    seq.length      -= length;
    n_tokens        -= length;
    GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
}
139+
140+
// Simple split: returns a ubatch holding up to `n_ubatch` consecutive tokens
// of the single sequence kept by this sbatch. The returned ubatch has
// equal_seqs == false and is empty when no sequence is left.
llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
    // never reserve more slots than there are tokens left
    if (n_tokens < n_ubatch) {
        n_ubatch = n_tokens;
    }

    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
    ubatch.equal_seqs = false;

    if (seq.empty()) {
        return ubatch;
    }

    llama_sbatch_seq & s = seq.front();
    const size_t length = s.length < n_ubatch ? s.length : n_ubatch;
    GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
    add_seq_to_ubatch(ubatch, s, length);

    return ubatch;
}
152+
153+
// Equal split: returns a ubatch in which every included sequence contributes
// the same number of tokens. Sequences are taken from the back of `seq`; the
// first one visited fixes the per-sequence token count (capped at `n_ubatch`).
llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
    // never reserve more slots than there are tokens left
    if (n_tokens < n_ubatch) {
        n_ubatch = n_tokens;
    }

    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
    if (seq.empty()) {
        return ubatch;
    }

    GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits

    size_t tokens_per_seq = 0; // set by the first sequence visited
    size_t tokens_used    = 0; // tokens placed in this ubatch so far

    // smallest first, because it's easier to split this way;
    // walking from the back lets finished sequences be popped in constant time.
    for (auto it = seq.rbegin(); it != seq.rend(); ++it) {
        llama_sbatch_seq & s = *it;
        GGML_ASSERT(s.length > 0);

        if (tokens_per_seq == 0) {
            tokens_per_seq = s.length < n_ubatch ? s.length : n_ubatch;
        }

        add_seq_to_ubatch(ubatch, s, tokens_per_seq);
        tokens_used += tokens_per_seq;

        // shared prompts can't be mixed with any of their sequences,
        // so it's safer to compute them in their own ubatch
        if (s.n_seq_id > 1) {
            break;
        }
        // stop when there isn't enough space for another sequence
        if (tokens_per_seq + tokens_used > n_ubatch) {
            break;
        }
    }

    return ubatch;
}
179+
180+
// Sequence-wise split: returns a ubatch holding up to `n_ubatch` tokens taken
// from the last sequence in `seq` only. The ubatch is empty when no sequence
// is left.
llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
    // never reserve more slots than there are tokens left
    if (n_tokens < n_ubatch) {
        n_ubatch = n_tokens;
    }

    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
    if (seq.empty()) {
        return ubatch;
    }

    llama_sbatch_seq & s = seq.back();
    const size_t length = s.length < n_ubatch ? s.length : n_ubatch;
    GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
    add_seq_to_ubatch(ubatch, s, length);

    return ubatch;
}
191+
192+
void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
193+
GGML_ASSERT(batch.n_tokens >= 0);
194+
this->batch = &batch;
195+
this->n_embd = n_embd;
196+
this->logits_all = logits_all;
197+
198+
n_tokens = batch.n_tokens;
199+
ids.resize(n_tokens);
200+
out_ids.clear();
201+
// TODO: reserve out_ids and seq
202+
203+
for (size_t i = 0; i < n_tokens; ++i) {
204+
ids[i] = i;
205+
}
206+
if (simple_split) {
207+
seq.resize(1);
208+
llama_sbatch_seq & s = seq[0];
209+
s.n_seq_id = 0;
210+
s.seq_id = nullptr;
211+
s.offset = 0;
212+
s.length = n_tokens;
213+
return;
214+
}
215+
std::sort(ids.begin(), ids.end(),
216+
[&batch](size_t a, size_t b) {
217+
int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
218+
int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
219+
// sort by seq_id, then by pos
220+
if (n_seq_a == n_seq_b) {
221+
if (batch.seq_id) {
222+
for (int32_t i = 0; i < n_seq_a; ++i) {
223+
llama_seq_id seq_id_a = batch.seq_id[a][i];
224+
llama_seq_id seq_id_b = batch.seq_id[b][i];
225+
// smaller seq_ids go first
226+
if (seq_id_a != seq_id_b) {
227+
return seq_id_a < seq_id_b;
228+
}
229+
}
230+
}
231+
// when all else is equal, sort by pos
232+
if (batch.pos) {
233+
return batch.pos[a] < batch.pos[b];
234+
}
235+
// no pos, sort by id
236+
return a < b;
237+
}
238+
// shared prompts go first
239+
return n_seq_a > n_seq_b;
240+
}
241+
);
242+
// init seq
243+
llama_sbatch_seq * last_seq = nullptr;
244+
245+
for (size_t i = 0; i < n_tokens; ++i) {
246+
const size_t bi = ids[i];
247+
const int32_t n_seqs = batch.n_seq_id[bi];
248+
llama_seq_id * seq_ids = batch.seq_id[bi];
249+
if (last_seq != nullptr) {
250+
bool same = n_seqs == last_seq->n_seq_id;
251+
for (int32_t j = 0; same && j < n_seqs; ++j) {
252+
if (seq_ids[j] != last_seq->seq_id[j]) {
253+
same = false;
254+
}
255+
}
256+
if (same) {
257+
last_seq->length += 1;
258+
continue;
259+
}
260+
}
261+
llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
262+
seq.push_back(new_seq);
263+
last_seq = &seq.back();
264+
}
265+
// keep shared prompts first at the end, then sort by length descending.
266+
std::sort(seq.begin(), seq.end(),
267+
[](llama_sbatch_seq & a, llama_sbatch_seq & b) {
268+
if (a.n_seq_id == b.n_seq_id) {
269+
return a.length > b.length;
270+
}
271+
return a.n_seq_id < b.n_seq_id;
272+
}
273+
);
274+
}
275+
276+
// Copies `in_batch` and fills in every optional array the caller left null,
// backing each one with storage owned by this object:
//   pos      - consecutive positions starting at p0
//   n_seq_id - seq_id_0.size() for every token
//   seq_id   - every token points at seq_id_0, followed by a null terminator entry
//   logits   - only the last token is flagged
llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
    batch = in_batch;
    GGML_ASSERT(batch.n_tokens > 0);

    const auto n_tok = batch.n_tokens;

    if (batch.pos == nullptr) {
        pos.resize(n_tok);
        for (int32_t i = 0; i < n_tok; ++i) {
            pos[i] = i + p0;
        }
        batch.pos = pos.data();
    }

    if (batch.n_seq_id == nullptr) {
        n_seq_id.resize(n_tok);
        for (auto & count : n_seq_id) {
            count = seq_id_0.size();
        }
        batch.n_seq_id = n_seq_id.data();
    }

    if (batch.seq_id == nullptr) {
        // one extra slot holds the terminating null pointer
        seq_id.resize(n_tok + 1);
        seq_id[n_tok] = nullptr;
        for (int32_t i = 0; i < n_tok; ++i) {
            seq_id[i] = seq_id_0.data();
        }
        batch.seq_id = seq_id.data();
    }

    if (batch.logits == nullptr) {
        logits.resize(n_tok);
        // n_tokens > 0 was asserted above, so back() is valid
        logits.back() = true;
        batch.logits = logits.data();
    }
}

0 commit comments

Comments
 (0)