@@ -82,6 +82,9 @@ int main(int argc, char ** argv) {
82
82
83
83
const int n_clients = 4 ;
84
84
85
+ // insert new requests as soon as the previous one is done
86
+ const bool hot_swap = true ;
87
+
85
88
#ifndef LOG_DISABLE_LOGS
86
89
log_set_target (log_filename_generator (" parallel" , " log" ));
87
90
LOG_TEE (" Log start\n " );
@@ -121,14 +124,23 @@ int main(int argc, char ** argv) {
121
124
std::vector<llama_token> batch_token;
122
125
std::vector<llama_pos> batch_pos;
123
126
std::vector<llama_seq_id> batch_seq_id;
127
+ std::vector<int8_t > batch_logits;
124
128
std::vector<client *> batch_clients;
125
129
126
- while (true ) {
130
+ int32_t n_total_prompt = 0 ;
131
+ int32_t n_total_gen = 0 ;
132
+
133
+ float t_avg = 0 .0f ;
134
+
135
+ const int32_t n_seq = 128 ;
136
+
137
+ while (g_seq_id < n_seq + n_clients) {
127
138
uint32_t n_tokens = 0 ;
128
139
129
140
batch_token.clear ();
130
141
batch_pos.clear ();
131
142
batch_seq_id.clear ();
143
+ batch_logits.clear ();
132
144
133
145
for (auto & client : clients) {
134
146
if (client.seq_id == -1 ) {
@@ -138,6 +150,7 @@ int main(int argc, char ** argv) {
138
150
batch_token.push_back (client.sampled );
139
151
batch_pos.push_back (client.n_decoded );
140
152
batch_seq_id.push_back (client.seq_id );
153
+ batch_logits.push_back (true );
141
154
batch_clients.push_back (&client);
142
155
client.n_decoded += 1 ;
143
156
client.i_batch = batch_token.size () - 1 ;
@@ -146,7 +159,9 @@ int main(int argc, char ** argv) {
146
159
if (batch_token.empty ()) {
147
160
// all sequences have ended - clear the entire KV cache
148
161
llama_kv_cache_rm_tokens (ctx, -1 , -1 );
162
+ }
149
163
164
+ if (hot_swap || batch_token.empty ()) {
150
165
for (auto & client : clients) {
151
166
if (client.seq_id == -1 ) {
152
167
client.seq_id = g_seq_id;
@@ -166,7 +181,10 @@ int main(int argc, char ** argv) {
166
181
batch_pos.push_back (i);
167
182
batch_seq_id.push_back (client.seq_id );
168
183
batch_clients.push_back (&client);
184
+ batch_logits.push_back (false );
169
185
}
186
+ batch_logits.back () = true ;
187
+
170
188
client.n_prompt = prompt_tokens.size ();
171
189
client.n_decoded = prompt_tokens.size ();
172
190
client.i_batch = batch_token.size () - 1 ;
@@ -186,6 +204,7 @@ int main(int argc, char ** argv) {
186
204
nullptr ,
187
205
batch_pos.data () + i,
188
206
batch_seq_id.data () + i,
207
+ batch_logits.data () + i,
189
208
0 , 0 , 0 , // unused
190
209
};
191
210
@@ -232,14 +251,20 @@ int main(int argc, char ** argv) {
232
251
233
252
const auto t_main_end = ggml_time_us ();
234
253
235
- printf (" \033 [1mClient %2d, seq %4d, prompt %4d t, response %4d t, speed: PP %5.2f t/s, TG %5.2f t/s, AVG %5.2f t/s \033 [0m: \n\n Input: %s\n Response: %s\n\n " ,
254
+ printf (" \033 [1mClient %2d, seq %4d, prompt %4d t, response %4d t, time %5.2f s, speed: PP %5.2f t/s, TG %5.2f t/s, AVG %5.2f t/s \033 [0m: \n\n Input: %s\n Response: %s\n\n " ,
236
255
client.id , client.seq_id , client.n_prompt , client.n_decoded - client.n_prompt ,
256
+ (t_main_end - client.t_start_prompt ) / 1e6 ,
237
257
(double ) (client.n_prompt ) / (client.t_start_gen - client.t_start_prompt ) * 1e6 ,
238
258
(double ) (client.n_decoded - client.n_prompt ) / (t_main_end - client.t_start_gen ) * 1e6 ,
239
259
(double ) (client.n_decoded ) / (t_main_end - client.t_start_prompt ) * 1e6 ,
240
260
::trim (client.input).c_str(),
241
261
::trim(client.response).c_str());
242
262
263
+ n_total_prompt += client.n_prompt ;
264
+ n_total_gen += client.n_decoded - client.n_prompt ;
265
+
266
+ t_avg += (t_main_end - client.t_start_prompt ) / 1e6 ;
267
+
243
268
client.seq_id = -1 ;
244
269
}
245
270
@@ -248,6 +273,11 @@ int main(int argc, char ** argv) {
248
273
}
249
274
}
250
275
276
+ LOG_TEE (" \n\n " );
277
+ LOG_TEE (" Total prompt tokens: %d\n " , n_total_prompt);
278
+ LOG_TEE (" Total gen tokens: %d\n " , n_total_gen);
279
+ LOG_TEE (" Avg time per seq: %.2f s\n " , t_avg / n_seq);
280
+
251
281
LOG_TEE (" \n\n " );
252
282
253
283
llama_print_timings (ctx);
0 commit comments