Commit 5b33b35

llama : scatter llama.cpp into multiple modules (wip)
1 parent a3c33b1 commit 5b33b35

19 files changed: +17168 -17086 lines changed

src/CMakeLists.txt

Lines changed: 8 additions & 1 deletion
@@ -9,9 +9,16 @@ llama_add_compile_flags()
 add_library(llama
             ../include/llama.h
             llama.cpp
-            llama-vocab.cpp
+            llama-arch.cpp
+            llama-batch.cpp
+            llama-context.cpp
+            llama-control-vector.cpp
             llama-grammar.cpp
+            llama-kv-cache.cpp
+            llama-mmap.cpp
+            llama-model.cpp
             llama-sampling.cpp
+            llama-vocab.cpp
             unicode.h
             unicode.cpp
             unicode-data.cpp
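
Editor's note: the consumer-facing surface is unchanged by this hunk; the new llama-*.cpp files are additional sources of the existing llama target, and ../include/llama.h stays where it was. A minimal sketch of an unaffected downstream program, assuming the repository is pulled in with add_subdirectory() and linked against the llama target as shown above:

// Illustrative only: application code keeps including just the umbrella header.
// The new llama-*.h headers introduced by this commit are internal to the library.
#include "llama.h"

int main() {
    llama_backend_init();   // public C API, untouched by the internal file split
    llama_backend_free();
    return 0;
}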

src/llama-arch.cpp

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+#include "llama-arch.h"

src/llama-arch.h

Lines changed: 1714 additions & 0 deletions
Large diffs are not rendered by default.

src/llama-batch.cpp

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+#include "llama-batch.h"

src/llama-batch.h

Lines changed: 330 additions & 0 deletions
#pragma once

#include "llama.h"

#include <vector>

// very similar to llama_batch,
// but has more metadata about sequences
struct llama_ubatch {
    bool equal_seqs;
    // TODO: whole_seqs for embeddings?

    uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
    uint32_t n_seq_tokens; // tokens per sequence
    uint32_t n_seqs;

    llama_token  *  token;    // [n_tokens]
    float        *  embd;     // [n_embd, n_tokens]
    llama_pos    *  pos;      // [n_tokens]
    int32_t      *  n_seq_id; // [n_seqs]
    llama_seq_id ** seq_id;   // [n_seqs]
    int8_t       *  output;   // [n_tokens]
};

struct llama_sbatch_seq {
    int32_t n_seq_id;
    llama_seq_id * seq_id;
    size_t offset;
    size_t length;
};

// sequence-length-aware batch splitting
struct llama_sbatch {
    // tokens left in this batch
    size_t n_tokens;

    size_t n_embd;

    bool logits_all; // TODO: remove once lctx.logits_all is removed too

    // sorted indices into the batch
    std::vector<size_t> ids;
    // batch indices of the output
    std::vector<size_t> out_ids;
    std::vector<llama_sbatch_seq> seq;

    const llama_batch * batch = nullptr;

    // buffers for the ubatch
    std::vector<llama_token>    ubatch_token;
    std::vector<float>          ubatch_embd;
    std::vector<llama_pos>      ubatch_pos;
    std::vector<int32_t>        ubatch_n_seq_id;
    std::vector<llama_seq_id *> ubatch_seq_id;
    std::vector<int8_t>         ubatch_output;

    llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false) {
        // clear empty sequences
        // the previous ubatch is assumed to be gone,
        // so nothing should refer to values in these sequences anymore.
        for (size_t i = seq.size(); i-- > 0;) {
            if (seq[i].length == 0) {
                seq.pop_back();
            } else {
                break;
            }
        }
        ubatch_token.resize(!has_embd ? n_ubatch : 0);
        ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
        ubatch_pos.resize(n_ubatch);
        ubatch_n_seq_id.resize(n_ubatch);
        ubatch_seq_id.resize(n_ubatch);
        ubatch_output.resize(n_ubatch);
        llama_ubatch ubatch = {
            /*equal_seqs   =*/ true,
            /*n_tokens     =*/ 0,
            /*n_seq_tokens =*/ 0,
            /*n_seqs       =*/ 0,
            /*token        =*/ !has_embd ? ubatch_token.data() : nullptr,
            /*embd         =*/ has_embd ? ubatch_embd.data() : nullptr,
            /*pos          =*/ ubatch_pos.data(),
            /*n_seq_id     =*/ ubatch_n_seq_id.data(),
            /*seq_id       =*/ ubatch_seq_id.data(),
            /*output       =*/ ubatch_output.data(),
        };
        return ubatch;
    }

    void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
        GGML_ASSERT(batch != nullptr);
        GGML_ASSERT(length <= seq.length);
        // Can only add sequences of equal lengths to a batch,
        // otherwise it isn't clear to which sequence a token belongs
        GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
        GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
        // NOTE: loops are separated for cache-friendliness
        if (batch->token) {
            if (ubatch.equal_seqs) {
                for (size_t i = 0; i < length; ++i) {
                    ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
                }
            } else {
                // simple split
                ubatch.token = batch->token + seq.offset;
            }
        } else {
            ubatch.token = nullptr;
        }
        if (batch->embd) {
            if (ubatch.equal_seqs) {
                for (size_t i = 0; i < length; ++i) {
                    memcpy(
                        ubatch.embd + n_embd * (ubatch.n_tokens + i),
                        batch->embd + n_embd * ids[seq.offset + i],
                        n_embd * sizeof(float)
                    );
                }
            } else {
                // simple split
                ubatch.embd = batch->embd + (n_embd * seq.offset);
            }
        } else {
            ubatch.embd = nullptr;
        }
        if (ubatch.equal_seqs) {
            for (size_t i = 0; i < length; ++i) {
                ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
            }
        } else {
            // simple split
            ubatch.pos = batch->pos + seq.offset;
        }
        if (ubatch.equal_seqs) {
            ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
            if (seq.seq_id) {
                ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
            }
        } else {
            // simple split
            if (batch->n_seq_id) {
                ubatch.n_seq_id = batch->n_seq_id + seq.offset;
            } else {
                for (size_t i = 0; i < length; ++i) {
                    ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
                }
            }
            if (batch->seq_id) {
                ubatch.seq_id = batch->seq_id + seq.offset;
            }
        }
        if (logits_all) {
            for (size_t i = 0; i < length; ++i) {
                ubatch.output[ubatch.n_tokens + i] = 1;
                out_ids.push_back(ids[seq.offset + i]);
            }
        } else if (batch->logits) {
            if (ubatch.equal_seqs) {
                for (size_t i = 0; i < length; ++i) {
                    size_t id = ids[seq.offset + i];
                    int8_t is_output = batch->logits[id];
                    ubatch.output[ubatch.n_tokens + i] = is_output;
                    if (is_output) { out_ids.push_back(id); }
                }
            } else {
                // simple split
                ubatch.output = batch->logits + seq.offset;
                for (size_t i = 0; i < length; ++i) {
                    if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
                }
            }
        } else {
            // only get last output
            for (size_t i = 0; i < length; ++i) {
                size_t id = ids[seq.offset + i];
                int8_t is_last = id == ids.size() - 1;
                ubatch.output[ubatch.n_tokens + i] = is_last;
                if (is_last) { out_ids.push_back(id); }
            }
        }
        if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
            ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
        }
        ubatch.n_tokens += length;
        ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
        seq.offset += length;
        seq.length -= length;
        n_tokens -= length;
        GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
    }

    // simple split, unknown number of sequences of unequal lengths
    llama_ubatch split_simple(size_t n_ubatch) {
        n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
        llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
        ubatch.equal_seqs = false;
        if (!seq.empty()) {
            llama_sbatch_seq & s = seq[0];
            size_t length = s.length < n_ubatch ? s.length : n_ubatch;
            GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
            add_seq_to_ubatch(ubatch, s, length);
        }
        return ubatch;
    }

    // make batches of equal-length sequences
    llama_ubatch split_equal(size_t n_ubatch) {
        n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
        llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
        if (!seq.empty()) {
            size_t length = 0;
            size_t n_tokens_in_ubatch = 0;
            GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
            // smallest first, because it's easier to split this way;
            // starting from the end to pop in constant time.
            for (size_t i = seq.size(); i-- > 0;) {
                llama_sbatch_seq & s = seq[i];
                GGML_ASSERT(s.length > 0);
                if (length == 0) {
                    length = s.length < n_ubatch ? s.length : n_ubatch;
                }
                add_seq_to_ubatch(ubatch, s, length);
                n_tokens_in_ubatch += length;
                // shared prompts can't be mixed with any of their sequences,
                // so it's safer to compute them in their own ubatch
                if (s.n_seq_id > 1) { break; }
                // stop when there isn't enough space for another sequence
                if (length + n_tokens_in_ubatch > n_ubatch) { break; }
            }
        }
        return ubatch;
    }

    // sequence-wise split
    llama_ubatch split_seq(size_t n_ubatch) {
        n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
        llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
        if (!seq.empty()) {
            llama_sbatch_seq & s = seq[seq.size() - 1];
            size_t length = s.length < n_ubatch ? s.length : n_ubatch;
            GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
            add_seq_to_ubatch(ubatch, s, length);
        }
        return ubatch;
    }

    void from_batch(const llama_batch & batch, const size_t n_embd, const bool simple_split = false, const bool logits_all = false) {
        GGML_ASSERT(batch.n_tokens >= 0);
        this->batch = &batch;
        this->n_embd = n_embd;
        this->logits_all = logits_all;

        n_tokens = batch.n_tokens;
        ids.resize(n_tokens);
        out_ids.clear();
        // TODO: reserve out_ids and seq

        for (size_t i = 0; i < n_tokens; ++i) {
            ids[i] = i;
        }
        if (simple_split) {
            seq.resize(1);
            llama_sbatch_seq & s = seq[0];
            s.n_seq_id = 0;
            s.seq_id = nullptr;
            s.offset = 0;
            s.length = n_tokens;
            return;
        }
        std::sort(ids.begin(), ids.end(),
            [&batch](size_t a, size_t b) {
                int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
                int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
                // sort by seq_id, then by pos
                if (n_seq_a == n_seq_b) {
                    if (batch.seq_id) {
                        for (int32_t i = 0; i < n_seq_a; ++i) {
                            llama_seq_id seq_id_a = batch.seq_id[a][i];
                            llama_seq_id seq_id_b = batch.seq_id[b][i];
                            // smaller seq_ids go first
                            if (seq_id_a != seq_id_b) {
                                return seq_id_a < seq_id_b;
                            }
                        }
                    }
                    // when all else is equal, sort by pos
                    if (batch.pos) {
                        return batch.pos[a] < batch.pos[b];
                    }
                    // no pos, sort by id
                    return a < b;
                }
                // shared prompts go first
                return n_seq_a > n_seq_b;
            }
        );
        // init seq
        llama_sbatch_seq * last_seq = nullptr;

        for (size_t i = 0; i < n_tokens; ++i) {
            const size_t bi = ids[i];
            const int32_t n_seqs = batch.n_seq_id[bi];
            llama_seq_id * seq_ids = batch.seq_id[bi];
            if (last_seq != nullptr) {
                bool same = n_seqs == last_seq->n_seq_id;
                for (int32_t j = 0; same && j < n_seqs; ++j) {
                    if (seq_ids[j] != last_seq->seq_id[j]) {
                        same = false;
                    }
                }
                if (same) {
                    last_seq->length += 1;
                    continue;
                }
            }
            llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
            seq.push_back(new_seq);
            last_seq = &seq.back();
        }
        // keep shared prompts first at the end, then sort by length descending.
        std::sort(seq.begin(), seq.end(),
            [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
                if (a.n_seq_id == b.n_seq_id) {
                    return a.length > b.length;
                }
                return a.n_seq_id < b.n_seq_id;
            }
        );
    }
};
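
Editor's note: a hedged sketch of how this interface appears meant to be driven, inferred only from the declarations above. from_batch() captures the batch (optionally sorting token indices by sequence) and the caller repeatedly carves off ubatches until n_tokens reaches zero. The function name process and the n_embd/n_ubatch parameters below are illustrative, not part of this commit:

// Illustrative only: drives llama_sbatch as declared in llama-batch.h above.
// The real caller (the decode path) lives elsewhere in this commit.
#include "llama-batch.h"

static void process(const llama_batch & batch, size_t n_embd, size_t n_ubatch) {
    llama_sbatch sbatch;

    // simple_split = true keeps the original token order (no per-sequence sorting),
    // so split_simple() must be used below; logits_all would mark every token as output.
    sbatch.from_batch(batch, n_embd, /*simple_split=*/true, /*logits_all=*/false);

    while (sbatch.n_tokens > 0) {
        // carve off at most n_ubatch tokens; equal_seqs is false for simple splits
        llama_ubatch ubatch = sbatch.split_simple(n_ubatch);

        // ... evaluate ubatch.token / ubatch.pos / ubatch.n_seq_id / ubatch.seq_id,
        //     keeping results only where ubatch.output[i] != 0 ...
        (void) ubatch;
    }
}

With simple_split = true, from_batch() sets up a single pseudo-sequence with n_seq_id == 0, which is exactly what split_simple() asserts; split_equal() and split_seq() serve the sorted, per-sequence path instead.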

0 commit comments
