
Commit a207561

examples : add example for batched decoding
1 parent d008733 commit a207561

8 files changed: +315 −125 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@ models-mnt
 /save-load-state
 /server
 /simple
+/batched
 /speculative
 /parallel
 /train-text-from-scratch

Makefile

Lines changed: 4 additions & 1 deletion
@@ -1,5 +1,5 @@
 # Define the default target now so that it is always the first target
-BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple save-load-state server embd-input-test gguf llama-bench baby-llama beam-search speculative parallel tests/test-c.o
+BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple batched save-load-state server embd-input-test gguf llama-bench baby-llama beam-search speculative parallel tests/test-c.o

 # Binaries only useful for tests
 TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama

@@ -519,6 +519,9 @@ main: examples/main/main.cpp build-info.h ggml.
 simple: examples/simple/simple.cpp build-info.h ggml.o llama.o common.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

+batched: examples/batched/batched.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
+
 quantize: examples/quantize/quantize.cpp build-info.h ggml.o llama.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
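For reference, a minimal sketch of building and running the new target with the Makefile from the repository root; the model path and prompt are the ones used in the example README below, so substitute your own:

```bash
# build only the new example (plain `make` builds every entry in BUILD_TARGETS)
make batched

# usage: ./batched MODEL_PATH [PROMPT] [PARALLEL]
./batched ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is" 4
```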

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ else()
     add_subdirectory(train-text-from-scratch)
     add_subdirectory(convert-llama2c-to-ggml)
     add_subdirectory(simple)
+    add_subdirectory(batched)
     add_subdirectory(speculative)
     add_subdirectory(parallel)
     add_subdirectory(embd-input)

examples/batched/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+set(TARGET batched)
+add_executable(${TARGET} batched.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
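Since the target is also registered in examples/CMakeLists.txt above, a CMake build works as well. A minimal sketch, assuming an out-of-tree build directory named `build` and the usual llama.cpp layout that places binaries under `build/bin` (both are assumptions, not part of this commit):

```bash
# configure, then build just the batched target
cmake -B build
cmake --build build --config Release --target batched

# run the resulting binary (location assumed from the default build layout)
./build/bin/batched ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is" 4
```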

examples/batched/README.md

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+# llama.cpp/example/batched
+
+The example demonstrates batched generation from a given prompt
+
+```bash
+./batched ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is" 4
+
+...
+
+main: n_len = 32, n_ctx = 2048, n_parallel = 4, n_kv_req = 113
+
+Hello my name is
+
+main: generating 4 sequences ...
+
+main: stream 0 finished
+main: stream 1 finished
+main: stream 2 finished
+main: stream 3 finished
+
+sequence 0:
+
+Hello my name is Shirley. I am a 25-year-old female who has been working for over 5 years as a b
+
+sequence 1:
+
+Hello my name is Renee and I'm a 32 year old female from the United States. I'm looking for a man between
+
+sequence 2:
+
+Hello my name is Diana. I am looking for a housekeeping job. I have experience with children and have my own transportation. I am
+
+sequence 3:
+
+Hello my name is Cody. I am a 3 year old neutered male. I am a very friendly cat. I am very playful and
+
+main: decoded 108 tokens in 3.57 s, speed: 30.26 t/s
+
+llama_print_timings: load time = 587.00 ms
+llama_print_timings: sample time = 2.56 ms / 112 runs ( 0.02 ms per token, 43664.72 tokens per second)
+llama_print_timings: prompt eval time = 4089.11 ms / 118 tokens ( 34.65 ms per token, 28.86 tokens per second)
+llama_print_timings: eval time = 0.00 ms / 1 runs ( 0.00 ms per token, inf tokens per second)
+llama_print_timings: total time = 4156.04 ms
+```

examples/batched/batched.cpp

Lines changed: 243 additions & 0 deletions
@@ -0,0 +1,243 @@
+#include "common.h"
+#include "llama.h"
+
+#include <cmath>
+#include <cstdio>
+#include <string>
+#include <vector>
+
+int main(int argc, char ** argv) {
+    gpt_params params;
+
+    if (argc == 1 || argv[1][0] == '-') {
+        printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL]\n" , argv[0]);
+        return 1 ;
+    }
+
+    int n_parallel = 1;
+
+    if (argc >= 2) {
+        params.model = argv[1];
+    }
+
+    if (argc >= 3) {
+        params.prompt = argv[2];
+    }
+
+    if (argc >= 4) {
+        n_parallel = std::atoi(argv[3]);
+    }
+
+    if (params.prompt.empty()) {
+        params.prompt = "Hello my name is";
+    }
+
+    // total length of the sequences including the prompt
+    const int n_len = 32;
+
+    // init LLM
+
+    llama_backend_init(params.numa);
+
+    llama_context_params ctx_params = llama_context_default_params();
+
+    ctx_params.seed  = 1234;
+    ctx_params.n_ctx = 2048;
+
+    llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
+
+    if (model == NULL) {
+        fprintf(stderr , "%s: error: unable to load model\n" , __func__);
+        return 1;
+    }
+
+    llama_context * ctx = llama_new_context_with_model(model, ctx_params);
+
+    if (ctx == NULL) {
+        fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
+        return 1;
+    }
+
+    // tokenize the prompt
+
+    std::vector<llama_token> tokens_list;
+    tokens_list = ::llama_tokenize(ctx, params.prompt, true);
+
+    const int n_ctx    = llama_n_ctx(ctx);
+    const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
+
+    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_parallel, n_kv_req);
+
+    // make sure the KV cache is big enough to hold all the prompt and generated tokens
+    if (n_kv_req > n_ctx) {
+        LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
+        LOG_TEE("%s: either reduce n_parallel or increase n_ctx\n", __func__);
+        return 1;
+    }
+
+    // print the prompt token-by-token
+
+    fprintf(stderr, "\n");
+
+    for (auto id : tokens_list) {
+        fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
+    }
+
+    fflush(stderr);
+
+    // create a llama_batch with size 512
+    // we use this object to submit token data for decoding
+
+    llama_batch batch = llama_batch_init(512, 0);
+
+    // evaluate the initial prompt
+    batch.n_tokens = tokens_list.size();
+
+    for (int32_t i = 0; i < batch.n_tokens; i++) {
+        batch.token[i]  = tokens_list[i];
+        batch.pos[i]    = i;
+        batch.seq_id[i] = 0;
+        batch.logits[i] = false;
+    }
+
+    // llama_decode will output logits only for the last token of the prompt
+    batch.logits[batch.n_tokens - 1] = true;
+
+    if (llama_decode(ctx, batch, params.n_threads) != 0) {
+        LOG_TEE("%s: llama_decode() failed\n", __func__);
+        return 1;
+    }
+
+    // assign the system KV cache to all parallel sequences
+    // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
+    for (int32_t i = 1; i < n_parallel; ++i) {
+        llama_kv_cache_seq_cp(ctx, 0, i, 0, batch.n_tokens);
+    }
+
+    if (n_parallel > 1) {
+        LOG_TEE("\n\n%s: generating %d sequences ...\n", __func__, n_parallel);
+    }
+
+    // main loop
+
+    // we will store the parallel decoded sequences in this vector
+    std::vector<std::string> streams(n_parallel);
+
+    // remember the batch index of the last token for each parallel sequence
+    // we need this to determine which logits to sample from
+    std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
+
+    int n_cur    = batch.n_tokens;
+    int n_decode = 0;
+
+    const auto t_main_start = ggml_time_us();
+
+    while (n_cur <= n_len) {
+        // evaluate the current batch with the transformer model
+        if (llama_decode(ctx, batch, params.n_threads)) {
+            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
+            return 1;
+        }
+
+        // prepare the next batch
+        batch.n_tokens = 0;
+
+        // sample the next token for each parallel sequence / stream
+        for (int32_t i = 0; i < n_parallel; ++i) {
+            if (i_batch[i] < 0) {
+                // the stream has already finished
+                continue;
+            }
+
+            auto n_vocab = llama_n_vocab(ctx);
+            auto logits  = llama_get_logits_ith(ctx, i_batch[i]);
+
+            std::vector<llama_token_data> candidates;
+            candidates.reserve(n_vocab);
+
+            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+                candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+            }
+
+            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+            const int   top_k = 40;
+            const float top_p = 0.9f;
+            const float temp  = 0.4f;
+
+            llama_sample_top_k(ctx, &candidates_p, top_k, 1);
+            llama_sample_top_p(ctx, &candidates_p, top_p, 1);
+            llama_sample_temp (ctx, &candidates_p, temp);
+
+            const llama_token new_token_id = llama_sample_token(ctx, &candidates_p);
+
+            //const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+
+            // is it an end of stream? -> mark the stream as finished
+            if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) {
+                i_batch[i] = -1;
+                LOG_TEE("\n");
+                if (n_parallel > 1) {
+                    LOG_TEE("%s: stream %d finished", __func__, i);
+                }
+
+                continue;
+            }
+
+            // if there is only one stream, we print immediately to stdout
+            if (n_parallel == 1) {
+                LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str());
+                fflush(stdout);
+            }
+
+            streams[i] += llama_token_to_piece(ctx, new_token_id);
+
+            // push this new token for next evaluation
+            batch.token [batch.n_tokens] = new_token_id;
+            batch.pos   [batch.n_tokens] = n_cur;
+            batch.seq_id[batch.n_tokens] = i;
+            batch.logits[batch.n_tokens] = true;
+
+            i_batch[i] = batch.n_tokens;
+
+            batch.n_tokens += 1;
+
+            n_decode += 1;
+        }
+
+        // all streams are finished
+        if (batch.n_tokens == 0) {
+            break;
+        }
+
+        n_cur += 1;
+    }
+
+    LOG_TEE("\n");
+
+    if (n_parallel > 1) {
+        LOG_TEE("\n");
+
+        for (int32_t i = 0; i < n_parallel; ++i) {
+            LOG_TEE("sequence %d:\n\n%s%s\n\n", i, params.prompt.c_str(), streams[i].c_str());
+        }
+    }
+
+    const auto t_main_end = ggml_time_us();
+
+    LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
+            __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
+
+    llama_print_timings(ctx);
+
+    fprintf(stderr, "\n");
+
+    llama_batch_free(batch);
+
+    llama_free(ctx);
+    llama_free_model(model);
+
+    llama_backend_free();
+
+    return 0;
+}
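A note on the two output paths in the loop above: with a single stream the sampled pieces are printed to stdout as they are generated, while with several streams the shared prompt is decoded once into sequence 0, its KV cache entries are copied to the other sequences with llama_kv_cache_seq_cp, and each full sequence is printed only after all streams finish. A quick way to compare both modes (same assumed model path as in the README):

```bash
# one stream: tokens are streamed to stdout as they are sampled
./batched ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is" 1

# four streams: the prompt is evaluated once and shared across sequences;
# the full sequences are printed after all streams finish
./batched ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is" 4
```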

examples/simple/README.md

Lines changed: 1 addition & 47 deletions
@@ -1,12 +1,9 @@
 # llama.cpp/example/simple
 
 The purpose of this example is to demonstrate a minimal usage of llama.cpp for generating text with a given prompt.
-The example demonstrates single-batch as well as parallel generation.
-
-## Single-batch generation
 
 ```bash
-./simple ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is" 1
+./simple ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is"
 
 ...
 
@@ -22,46 +19,3 @@ llama_print_timings: prompt eval time = 655.63 ms / 10 tokens ( 65.56 ms
 llama_print_timings: eval time = 2180.97 ms / 27 runs ( 80.78 ms per token, 12.38 tokens per second)
 llama_print_timings: total time = 2891.13 ms
 ```
-
-## Parallel generation
-
-```bash
-./simple ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is" 4
-
-...
-
-main: n_len = 32, n_ctx = 2048, n_parallel = 4, n_kv_req = 113
-
-Hello my name is
-
-main: generating 4 sequences ...
-
-main: stream 0 finished
-main: stream 1 finished
-main: stream 2 finished
-main: stream 3 finished
-
-sequence 0:
-
-Hello my name is Shirley. I am a 25-year-old female who has been working for over 5 years as a b
-
-sequence 1:
-
-Hello my name is Renee and I'm a 32 year old female from the United States. I'm looking for a man between
-
-sequence 2:
-
-Hello my name is Diana. I am looking for a housekeeping job. I have experience with children and have my own transportation. I am
-
-sequence 3:
-
-Hello my name is Cody. I am a 3 year old neutered male. I am a very friendly cat. I am very playful and
-
-main: decoded 108 tokens in 3.57 s, speed: 30.26 t/s
-
-llama_print_timings: load time = 587.00 ms
-llama_print_timings: sample time = 2.56 ms / 112 runs ( 0.02 ms per token, 43664.72 tokens per second)
-llama_print_timings: prompt eval time = 4089.11 ms / 118 tokens ( 34.65 ms per token, 28.86 tokens per second)
-llama_print_timings: eval time = 0.00 ms / 1 runs ( 0.00 ms per token, inf tokens per second)
-llama_print_timings: total time = 4156.04 ms
-```
