// The module takes in a string as input and emits a string as output.

#include <executorch/examples/models/llama2/runner/runner.h>
+ #include <executorch/extension/evalue_util/print_evalue.h>
#include <executorch/extension/runner_util/managed_tensor.h>

#include <ctime>
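
A minimal usage sketch for context, since this file implements the string-in/string-out runner described above. The constructor arguments (model path, tokenizer path, temperature) and the exact generate() overload are assumptions for illustration, not taken from this diff:

#include <executorch/examples/models/llama2/runner/runner.h>

int main() {
  // Hypothetical paths and temperature; the real constructor may differ.
  ::torch::executor::Runner runner("llama2.pte", "tokenizer.bin", 0.8f);
  // generate() tokenizes the prompt, runs the model step by step, and prints
  // each decoded piece as it is produced.
  runner.generate("Hello, how are you?", /*seq_len=*/128);
  return 0;
}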
@@ -121,24 +122,6 @@ T Runner::getMetadataHelper(std::string method_name, T default_val) {
  return res;
}

- std::vector<exec_aten::SizesType> Runner::getKVCacheShape() {
-   // shape: (n_layers, args.max_batch_size, args.max_seq_len, self.n_kv_heads,
-   // self.head_dim)
-   std::vector<std::string> methods = {
-       "get_n_layers",
-       "get_max_batch_size",
-       "get_max_seq_len",
-       "get_n_kv_heads",
-       "get_head_dim"};
-   std::vector<int64_t> default_values = {12, 1, 128, 32, 128};
-   std::vector<exec_aten::SizesType> result;
-   for (int i = 0; i < methods.size(); ++i) {
-     // convert from int64_t to int32_t
-     result.push_back(getMetadataHelper<int64_t>(methods[i], default_values[i]));
-   }
-   return result;
- }
-
template <typename T>
int32_t Runner::logitsToToken(
    const exec_aten::Tensor& logits_tensor,
@@ -155,6 +138,73 @@ int32_t Runner::logitsToToken(
  return sampler_->sample(logits_last);
}

+ // Given an input token, set up the inputs for the model, execute a single
+ // step, and return the logits tensor.
+ Result<torch::executor::Tensor> Runner::run_model_step(
+     int64_t input_token,
+     ManagedTensor& managed_tokens,
+     ManagedTensor& managed_start_pos,
+     size_t max_seq_len) {
+   // ET_LOG(Info, "Input token %" PRIu64, input_token);
+   if (use_kv_cache_) {
+     std::vector<EValue> inputs;
+     auto tokens = managed_tokens.get_aliasing_tensor();
+     auto start_pos = managed_start_pos.get_aliasing_tensor();
+
+     // When using the kv-cache, our input is always 1 token, so just update to
+     // the latest.
+     tokens.mutable_data_ptr<int64_t>()[0] = input_token;
+
+     // inputs: [tokens, start_pos]
+     inputs.push_back(tokens);
+     inputs.push_back(start_pos);
+
+     Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
+     ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
+     ET_CHECK_MSG(
+         outputs_res.get().size() == 1,
+         "More than one output returned from executing LLM.");
+     ET_CHECK_MSG(
+         outputs_res.get()[0].isTensor(),
+         "Non Tensor Output returned from executing LLM");
+
+     // Bump start_pos by 1.
+     start_pos.mutable_data_ptr<int64_t>()[0]++;
+
+     // Return the logits tensor.
+     return outputs_res.get()[0].toTensor();
+   } else { // no kv cache
+     std::vector<EValue> inputs;
+     auto tokens = managed_tokens.get_aliasing_tensor();
+     (void)managed_start_pos; // unused
+
+     // When not using the kv-cache, our input is the entire history of tokens
+     // we have seen, so resize the input to be 1 larger and append the new
+     // token to the end.
+     // TODO: does this work in ATen mode?
+     tokens.mutable_data_ptr<int64_t>()[tokens.size(1) - 1] = input_token;
+
+     // inputs: [tokens]
+     inputs.push_back(tokens);
+
+     Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
+     ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
+     ET_CHECK_MSG(
+         outputs_res.get().size() == 1,
+         "More than one output returned from executing LLM.");
+     ET_CHECK_MSG(
+         outputs_res.get()[0].isTensor(),
+         "Non Tensor Output returned from executing LLM");
+
+     if (tokens.size(1) < max_seq_len) {
+       // Resize the tokens tensor to be 1 larger for the next step.
+       managed_tokens.resize({1, static_cast<int>(tokens.size(1) + 1)});
+     }
+
+     // Return the logits tensor.
+     return outputs_res.get()[0].toTensor();
+   }
+ }
+
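
A worked trace of the two run_model_step() paths above; illustrative only, it follows directly from the code:

// use_kv_cache_ == true: tokens stays 1x1 and start_pos is bumped each call.
//   step 0: tokens = [[t0]], start_pos = [0] -> forward -> start_pos = [1]
//   step 1: tokens = [[t1]], start_pos = [1] -> forward -> start_pos = [2]
// use_kv_cache_ == false: tokens holds the whole history and grows by one.
//   step 0: tokens = [[t0]]     -> forward -> resized to {1, 2}
//   step 1: tokens = [[t0, t1]] -> forward -> resized to {1, 3}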

Error Runner::generate(
    const std::string& prompt,
    int32_t seq_len,
@@ -189,9 +239,6 @@ Error Runner::generate(
      prompt_tokens,
      &num_prompt_tokens);

-   for (int i = 0; i < num_prompt_tokens; i++) {
-     ET_LOG(Info, "prompt_tokens[%d]: %d", i, prompt_tokens[i]);
-   }
  ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
  ET_CHECK_MSG(
      num_prompt_tokens < max_seq_len_,
@@ -202,89 +249,94 @@ Error Runner::generate(
      "Sequence length exceeded - please increase the seq_len value passed to generate()");

  // start the main loop
-   int next; // will store the next token in the sequence
-   int64_t pos = num_prompt_tokens - 1; // position in the sequence
-   int token = prompt_tokens[pos]; // prefill starts from 0 to num_prompt_tokens
-   int logits_index = 0; // index of the logits tensor in the output
-   std::vector<exec_aten::SizesType> input_shape = {1, 1};
-   std::vector<exec_aten::SizesType> pos_shape = {1};
+   int64_t pos = 0; // position in the sequence
+
  std::vector<int64_t> token_data; // allocate space for the tokens
-   std::vector<int64_t> pos_data; // allocate space for the tokens
+   std::vector<exec_aten::SizesType> token_shape = {1, seq_len};
+
+   std::vector<int64_t> start_pos_data; // allocate space for the tokens
+   std::vector<exec_aten::SizesType> start_pos_shape = {1};

  if (use_kv_cache_) {
-     // set pos to 0, refill token by token
-     pos = 0;
+     // hard code these to size 1, as the kv cache is locked to a static size
+     // right now.
    token_data.resize(1);
-     pos_data.resize(seq_len);
+     token_shape[1] = 1;
+     start_pos_data.resize(1);
+     start_pos_data.push_back(0);
  } else {
-     // reserve data for tokens, notice the size is still 0.
+     // reserve data for tokens; notice the size is still 0 but the capacity
+     // is seq_len.
    token_data.resize(seq_len);
  }

  // initialize tensor wrappers
-   ManagedTensor pos_managed(
-       pos_data.data(), pos_data.size(), pos_shape, ScalarType::Long);
-
-   // copy prompt tokens into data
-   for (int i = 0; i <= pos; ++i) {
-     // @lint-ignore CLANGTIDY facebook-hte-LocalUncheckedArrayBounds
-     token_data[i] = prompt_tokens[i];
-     if (i > 0) {
-       printf(
-           "%s",
-           ET_UNWRAP(
-               tokenizer_->decode(prompt_tokens[i - 1], prompt_tokens[i])));
+   ManagedTensor tokens_managed(
+       token_data.data(),
+       128, // TODO: clean up the unused 128 here, as ManagedTensor ignores
+            // this arg in the ctor
+       token_shape,
+       ScalarType::Long);
+   // Create with the max shape to appropriately set the capacity of this
+   // tensor, then resize back to 1 for the first input.
+   tokens_managed.resize({1, 1});
+
+   ManagedTensor start_pos_managed(
+       start_pos_data.data(), 128, start_pos_shape, ScalarType::Long);
+
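
A minimal sketch of the ManagedTensor pattern used here, relying only on the constructor, resize(), and get_aliasing_tensor() calls visible in this diff: the backing buffer is allocated for the largest shape, the wrapper is created with that maximum shape to set its capacity, and it is then resized down to the shape of the current step:

// Sketch only; mirrors the tokens_managed setup above.
std::vector<exec_aten::SizesType> shape = {1, seq_len};  // capacity shape
std::vector<int64_t> buf(seq_len, 0);                    // backing storage
ManagedTensor wrapped(buf.data(), 128, shape, ScalarType::Long);
wrapped.resize({1, 1});                                  // first step: one token
wrapped.get_aliasing_tensor().mutable_data_ptr<int64_t>()[0] = prompt_tokens[0];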
+   int64_t prev_token;
+   int64_t cur_token = prompt_tokens[0];
+
+   // If we aren't using the kv cache, then we can batch prefill the prompt.
+   if (!use_kv_cache_) {
+     tokens_managed.resize({1, num_prompt_tokens});
+     for (int i = 0; i < num_prompt_tokens - 1; i++) {
+       tokens_managed.get_aliasing_tensor().mutable_data_ptr<int64_t>()[i] =
+           prompt_tokens[i];
+     }
+     // prefill tokens up to the last prompt token, and then enter the loop
+     // with the last prompt token as the current token.
+     cur_token = prompt_tokens[num_prompt_tokens - 1];
+     pos = num_prompt_tokens - 1;
+
+     // Print the prompt for consistent output between single token prefill
+     // and batch prefill.
+     int prev = prompt_tokens[0];
+     int cur;
+     for (int i = 1; i < num_prompt_tokens; i++) {
+       cur = prompt_tokens[i];
+       auto piece_res = tokenizer_->decode(prev, cur);
+       ET_CHECK_OK_OR_RETURN_ERROR(piece_res.error());
+       util::safe_printf(piece_res.get());
+       fflush(stdout);
+       prev = cur;
    }
  }

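A worked example of the prefill bookkeeping above, assuming a four-token prompt [t0, t1, t2, t3] (illustrative, not part of the diff):

// use_kv_cache_ == false: tokens_managed is resized to {1, 4} and filled with
//   t0..t2, cur_token = t3, pos = 3, so the first loop iteration below runs
//   the whole prompt in one forward pass and samples the first new token.
// use_kv_cache_ == true: cur_token = t0, pos = 0, and the loop feeds t1..t3
//   one step at a time through the "force the next token" prefill branch.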
-   // create a 1xN int tensor with next as value
-   while (pos + 1 < seq_len) {
-     // ET_LOG(Info, "Generating step %d...", pos);
-     // set the current token in the tensor
-     std::vector<EValue> inputs;
-     if (use_kv_cache_) {
-       token_data[0] = token;
-       input_shape[1] = 1;
-       // inputs: [tokens, start_pos, k_cache, v_cache]
-       inputs.emplace_back(pos_managed.get_aliasing_tensor());
-     } else {
-       // @lint-ignore CLANGTIDY facebook-hte-LocalUncheckedArrayBounds
-       token_data[pos] = token;
-       input_shape[1] = pos + 1;
-     }
-     ManagedTensor token_managed(
-         token_data.data(), token_data.size(), input_shape, ScalarType::Long);
-     inputs.insert(inputs.begin(), token_managed.get_aliasing_tensor());
-     // For kv cache, inputs: [tokens, start_pos, k_cache, v_cache]
-     // Otherwise inputs: [tokens]
-     Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
-     ET_CHECK_MSG(
-         outputs_res.ok(),
-         "Execution of method forward failed with status 0x%" PRIx32,
-         static_cast<int32_t>(outputs_res.error()));
-     // ET_LOG(Info, "Model executed successfully.");
+   // Generate our tokens
+   while (pos < seq_len - 1) {
+     // Run the model
+     Result<torch::executor::Tensor> logits_res =
+         run_model_step(cur_token, tokens_managed, start_pos_managed, seq_len);

-     std::vector<EValue> outputs = outputs_res.get();
-     // Check the outputs.
-     ET_CHECK_MSG(
-         outputs.size() == 1 && outputs.at(0).isTensor(),
-         "Expecting output to have exactly 1 tensor output. Got %zu outputs.",
-         outputs.size());
    if (pos == num_prompt_tokens) {
      timers_.first_token_ms = util::time_in_ms();
    } else if (pos == num_prompt_tokens - 1) {
      timers_.prompt_eval_end_ms = util::time_in_ms();
    }
-     int32_t next_tok;
-     exec_aten::Tensor logits_tensor = outputs.at(logits_index).toTensor();
+
+     ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
+     exec_aten::Tensor& logits_tensor = logits_res.get();
+
+     prev_token = cur_token;
+
    long sample_start_time_ms = util::time_in_ms();
    switch (logits_tensor.scalar_type()) {
      case ScalarType::Float: {
-         next_tok = logitsToToken<float>(logits_tensor, pos, 0);
+         cur_token = logitsToToken<float>(logits_tensor, pos, 0);
        break;
      }
      case ScalarType::Half: {
-         next_tok = logitsToToken<exec_aten::Half>(logits_tensor, pos, 0);
+         cur_token = logitsToToken<exec_aten::Half>(logits_tensor, pos, 0);
        break;
      }
      default:
@@ -299,19 +351,12 @@ Error Runner::generate(
    // advance the state machine
    if (pos < num_prompt_tokens - 1) {
      // prefill, force the next token to be the next prompt token
-       next = prompt_tokens[pos + 1];
-     } else {
-       // otherwise sample the next token from the logits
-       next = next_tok;
+       cur_token = prompt_tokens[pos + 1];
    }
-     // ET_LOG(Info, "Output saved, next = %d", next);
    pos++;
-     if (use_kv_cache_) {
-       pos_data.at(0) = pos;
-     }

    // print the token as string, decode it with the Tokenizer object
-     auto piece_res = tokenizer_->decode(token, next);
+     auto piece_res = tokenizer_->decode(prev_token, cur_token);
    ET_CHECK(piece_res.ok());
    const char* piece = piece_res.get();
@@ -328,22 +373,20 @@ Error Runner::generate(
    }

    // data-dependent terminating condition: we have n_eos_ number of EOS
-     if (pos >= num_prompt_tokens && next == eos_id_) {
+     if (pos >= num_prompt_tokens && cur_token == eos_id_) {
      printf("\n");
      ET_LOG(Info, "\nReached to the end of generation");
      break;
    }
-
-     token = next;
  }
  timers_.inference_end_ms = util::time_in_ms();
  printf("\n");

-   if (pos + 1 == seq_len) {
+   if (pos == seq_len) {
    ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
  }

-   timers_.printReport(num_prompt_tokens, (pos + 1) - num_prompt_tokens);
+   timers_.printReport(num_prompt_tokens, pos - num_prompt_tokens);

  delete[] prompt_tokens;
  return Error::Ok;