Commit fa0e677

llama : extend batch API to select which logits to output

1 parent 897cacc
File tree: 4 files changed, +46 -6 lines

    examples/embd-input/embd-input-lib.cpp
    examples/parallel/parallel.cpp
    llama.cpp
    llama.h

examples/embd-input/embd-input-lib.cpp

Lines changed: 1 addition & 1 deletion

@@ -79,7 +79,7 @@ bool eval_float(void * model, float * input, int N){
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, };
+        llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, n_past, 1, 0, };
         if (llama_decode(ctx, batch, params.n_threads)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return false;

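The only functional change here is the extra `nullptr` threaded through the positional initializer for the new `logits` field. Annotated as a sketch (field names taken from the updated `llama_batch` struct in llama.h below), the initializer maps out as follows; `logits = nullptr` means no per-token selection, i.e. the pre-commit behavior:

    // the same 9-field positional initializer, with the field each
    // value lands in spelled out (order per llama.h)
    llama_batch batch = {
        /*n_tokens   =*/ uint32_t(n_eval),
        /*token      =*/ nullptr,           // embeddings input, so no token ids
        /*embd       =*/ input + i*n_emb,
        /*pos        =*/ nullptr,
        /*seq_id     =*/ nullptr,
        /*logits     =*/ nullptr,           // nullptr => extract logits as before
        /*all_pos_0  =*/ n_past,
        /*all_pos_1  =*/ 1,
        /*all_seq_id =*/ 0,
    };
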
examples/parallel/parallel.cpp

Lines changed: 32 additions & 2 deletions

@@ -82,6 +82,9 @@ int main(int argc, char ** argv) {
 
     const int n_clients = 4;
 
+    // insert new requests as soon as the previous one is done
+    const bool hot_swap = true;
+
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("parallel", "log"));
     LOG_TEE("Log start\n");

@@ -121,14 +124,23 @@ int main(int argc, char ** argv) {
     std::vector<llama_token>  batch_token;
     std::vector<llama_pos>    batch_pos;
     std::vector<llama_seq_id> batch_seq_id;
+    std::vector<int8_t>       batch_logits;
     std::vector<client *>     batch_clients;
 
-    while (true) {
+    int32_t n_total_prompt = 0;
+    int32_t n_total_gen    = 0;
+
+    float t_avg = 0.0f;
+
+    const int32_t n_seq = 128;
+
+    while (g_seq_id < n_seq + n_clients) {
         uint32_t n_tokens = 0;
 
         batch_token.clear();
         batch_pos.clear();
         batch_seq_id.clear();
+        batch_logits.clear();
 
         for (auto & client : clients) {
             if (client.seq_id == -1) {

@@ -138,6 +150,7 @@ int main(int argc, char ** argv) {
             batch_token.push_back(client.sampled);
             batch_pos.push_back(client.n_decoded);
             batch_seq_id.push_back(client.seq_id);
+            batch_logits.push_back(true);
             batch_clients.push_back(&client);
             client.n_decoded += 1;
             client.i_batch = batch_token.size() - 1;

@@ -146,7 +159,9 @@ int main(int argc, char ** argv) {
         if (batch_token.empty()) {
             // all sequences have ended - clear the entire KV cache
             llama_kv_cache_rm_tokens(ctx, -1, -1);
+        }
 
+        if (hot_swap || batch_token.empty()) {
             for (auto & client : clients) {
                 if (client.seq_id == -1) {
                     client.seq_id = g_seq_id;

@@ -166,7 +181,10 @@ int main(int argc, char ** argv) {
                         batch_pos.push_back(i);
                         batch_seq_id.push_back(client.seq_id);
                         batch_clients.push_back(&client);
+                        batch_logits.push_back(false);
                     }
+                    batch_logits.back() = true;
+
                     client.n_prompt  = prompt_tokens.size();
                     client.n_decoded = prompt_tokens.size();
                     client.i_batch   = batch_token.size() - 1;

@@ -186,6 +204,7 @@ int main(int argc, char ** argv) {
                 nullptr,
                 batch_pos.data()    + i,
                 batch_seq_id.data() + i,
+                batch_logits.data() + i,
                 0, 0, 0, // unused
             };
 

@@ -232,14 +251,20 @@ int main(int argc, char ** argv) {
 
                     const auto t_main_end = ggml_time_us();
 
-                    printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, speed: PP %5.2f t/s, TG %5.2f t/s, AVG %5.2f t/s \033[0m: \n\nInput: %s\nResponse: %s\n\n",
+                    printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, time %5.2f s, speed: PP %5.2f t/s, TG %5.2f t/s, AVG %5.2f t/s \033[0m: \n\nInput: %s\nResponse: %s\n\n",
                            client.id, client.seq_id, client.n_prompt, client.n_decoded - client.n_prompt,
+                           (t_main_end - client.t_start_prompt) / 1e6,
                            (double) (client.n_prompt                   ) / (client.t_start_gen - client.t_start_prompt) * 1e6,
                            (double) (client.n_decoded - client.n_prompt) / (t_main_end - client.t_start_gen) * 1e6,
                            (double) (client.n_decoded                  ) / (t_main_end - client.t_start_prompt) * 1e6,
                            ::trim(client.input).c_str(),
                            ::trim(client.response).c_str());
 
+                    n_total_prompt += client.n_prompt;
+                    n_total_gen    += client.n_decoded - client.n_prompt;
+
+                    t_avg += (t_main_end - client.t_start_prompt) / 1e6;
+
                     client.seq_id = -1;
                 }
 

@@ -248,6 +273,11 @@ int main(int argc, char ** argv) {
         }
     }
 
+    LOG_TEE("\n\n");
+    LOG_TEE("Total prompt tokens: %d\n", n_total_prompt);
+    LOG_TEE("Total gen tokens:    %d\n", n_total_gen);
+    LOG_TEE("Avg time per seq:    %.2f s\n", t_avg / n_seq);
+
     LOG_TEE("\n\n");
 
     llama_print_timings(ctx);

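The prompt-ingestion hunk above shows the point of the new field: during prefill, only the final prompt token is ever sampled from, so `batch_logits` is false for every position except the last. The saving is substantial; with illustrative numbers, a 32000-entry vocabulary and a 512-token prompt, skipping the extraction avoids copying about 512 × 32000 × 4 bytes ≈ 65 MB of logits per prompt. A standalone sketch of the same selection, as a hypothetical helper (not part of this commit):

    #include <cstdint>
    #include <vector>

    // Build a logits-selection mask for an n-token prompt: request logits
    // only for the final position, the only one that gets sampled.
    static std::vector<int8_t> logits_last_only(size_t n_tokens) {
        std::vector<int8_t> mask(n_tokens, 0); // 0 = skip this token's logits
        if (!mask.empty()) {
            mask.back() = 1;                   // 1 = extract logits here
        }
        return mask;
    }
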
llama.cpp

Lines changed: 12 additions & 2 deletions

@@ -4140,7 +4140,16 @@ static bool llama_eval_internal(
 
         if (lctx.logits_all) {
             logits_out.resize(n_vocab * n_tokens);
-            memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens);
+            if (batch.logits) {
+                for (uint32_t i = 0; i < n_tokens; i++) {
+                    if (batch.logits[i] == 0) {
+                        continue;
+                    }
+                    memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab);
+                }
+            } else {
+                memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens);
+            }
         } else {
             // return result for just the last token
             logits_out.resize(n_vocab);

@@ -7318,7 +7327,7 @@ int llama_eval_embd(
                              int   n_threads) {
     llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1);
 
-    llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, n_past, 1, 0, };
+    llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, };
 
     if (!llama_eval_internal(*ctx, batch, n_threads)) {
         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);

@@ -7346,6 +7355,7 @@ struct llama_batch llama_batch_get_one(
        /*embd       =*/ nullptr,
        /*pos        =*/ nullptr,
        /*seq_id     =*/ nullptr,
+       /*logits     =*/ nullptr,
        /*all_pos_0  =*/ pos_0,
        /*all_pos_1  =*/ 1,
        /*all_seq_id =*/ seq_id,

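In `llama_eval_internal`, token `i`'s logits occupy row `i` of the result tensor, i.e. the floats in `[n_vocab*i, n_vocab*(i+1))`, which is why the new branch can copy row by row and skip unflagged rows; a null `batch.logits` falls back to the old copy-everything path. A self-contained sketch of that row-selective copy (shapes assumed from the surrounding code):

    #include <cstdint>
    #include <cstring>
    #include <vector>

    // Copy only the flagged rows of an (n_tokens x n_vocab) float matrix.
    // A null mask means "copy everything", matching the pre-commit path.
    static void copy_selected_logits(std::vector<float> & out, const float * res,
                                     uint32_t n_tokens, uint32_t n_vocab,
                                     const int8_t * mask) {
        out.resize((size_t) n_vocab * n_tokens);
        if (!mask) {
            memcpy(out.data(), res, sizeof(float)*n_vocab*n_tokens);
            return;
        }
        for (uint32_t i = 0; i < n_tokens; i++) {
            if (mask[i] == 0) {
                continue; // row i was not requested; leave it unwritten
            }
            memcpy(out.data() + (size_t) n_vocab*i, res + (size_t) n_vocab*i, sizeof(float)*n_vocab);
        }
    }

One consequence for callers: unflagged rows are never written, so only the logits of tokens flagged in `batch.logits` are meaningful in the output buffer.
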
llama.h

Lines changed: 1 addition & 1 deletion

@@ -70,11 +70,11 @@ extern "C" {
     typedef struct llama_batch {
         uint32_t n_tokens;
 
-        // TODO: not sure about these consts - might just get in the way all the time with no benefit
         const llama_token  * token;
         const float        * embd;
         const llama_pos    * pos;
         const llama_seq_id * seq_id;
+        const int8_t       * logits; // if 0, do not extract logits for that token
 
         // NOTE: helpers for smooth API transition - can be deprecated in the future
         //       for future-proof code, use the above fields instead and ignore everything below

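Putting the new field together from the caller's side, a minimal sketch (the arrays and their contents are hypothetical; the struct layout and the `0, 0, 0, // unused` convention when `pos`/`seq_id` are supplied come from the diffs above):

    llama_token  tok[3]  = { 1, 2, 3 };      // hypothetical token ids
    llama_pos    pos[3]  = { 0, 1, 2 };
    llama_seq_id seq[3]  = { 0, 0, 0 };
    int8_t       want[3] = { 0, 0, 1 };      // logits only for the last token

    llama_batch batch = {
        /*n_tokens   =*/ 3,
        /*token      =*/ tok,
        /*embd       =*/ nullptr,
        /*pos        =*/ pos,
        /*seq_id     =*/ seq,
        /*logits     =*/ want,
        /*all_pos_0  =*/ 0, /*all_pos_1 =*/ 0, /*all_seq_id =*/ 0, // unused here
    };
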