@@ -121,24 +121,6 @@ T Runner::getMetadataHelper(std::string method_name, T default_val) {
   return res;
 }
 
-std::vector<exec_aten::SizesType> Runner::getKVCacheShape() {
-  // shape: (n_layers, args.max_batch_size, args.max_seq_len, self.n_kv_heads,
-  // self.head_dim)
-  std::vector<std::string> methods = {
-      "get_n_layers",
-      "get_max_batch_size",
-      "get_max_seq_len",
-      "get_n_kv_heads",
-      "get_head_dim"};
-  std::vector<int64_t> default_values = {12, 1, 128, 32, 128};
-  std::vector<exec_aten::SizesType> result;
-  for (int i = 0; i < methods.size(); ++i) {
-    // convert from int64_t to int32_t
-    result.push_back(getMetadataHelper<int64_t>(methods[i], default_values[i]));
-  }
-  return result;
-}
-
 template <typename T>
 int32_t Runner::logitsToToken(
     const exec_aten::Tensor& logits_tensor,
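// ---------------------------------------------------------------------------
// A standalone sketch (not part of the diff): what the removed getKVCacheShape()
// did. It read five int64_t metadata values through getMetadataHelper() and
// narrowed them into a 32-bit shape vector. The method names and fallback
// values mirror the removed code; the lambda below only stands in for
// getMetadataHelper(), since no model is loaded here.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

int main() {
  const std::vector<std::string> methods = {
      "get_n_layers", "get_max_batch_size", "get_max_seq_len",
      "get_n_kv_heads", "get_head_dim"};
  const std::vector<int64_t> default_values = {12, 1, 128, 32, 128};

  // Stand-in for getMetadataHelper<int64_t>(name, default_val): just returns
  // the fallback value.
  auto get_metadata = [&](size_t i) -> int64_t { return default_values[i]; };

  std::vector<int32_t> kv_cache_shape; // exec_aten::SizesType is a 32-bit type
  for (size_t i = 0; i < methods.size(); ++i) {
    kv_cache_shape.push_back(static_cast<int32_t>(get_metadata(i)));
  }

  for (int32_t d : kv_cache_shape) {
    std::cout << d << ' '; // prints: 12 1 128 32 128
  }
  std::cout << '\n';
  return 0;
}
// ---------------------------------------------------------------------------
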
@@ -155,6 +137,73 @@ int32_t Runner::logitsToToken(
   return sampler_->sample(logits_last);
 }
 
+// Given an input token, set up the inputs for the model and execute a single
+// step, returning the logits tensor.
+Result<torch::executor::Tensor> Runner::run_model_step(
+    int64_t input_token,
+    ManagedTensor& managed_tokens,
+    ManagedTensor& managed_start_pos,
+    size_t max_seq_len) {
+  // ET_LOG(Info, "Input token %" PRIu64, input_token);
+  if (use_kv_cache_) {
+    std::vector<EValue> inputs;
+    auto tokens = managed_tokens.get_aliasing_tensor();
+    auto start_pos = managed_start_pos.get_aliasing_tensor();
+
+    // When using kv-cache our input is always 1 token, so just update to the
+    // latest.
+    tokens.mutable_data_ptr<int64_t>()[0] = input_token;
+
+    // inputs:[tokens, start_pos]
+    inputs.push_back(tokens);
+    inputs.push_back(start_pos);
+
+    Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
+    ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
+    ET_CHECK_MSG(
+        outputs_res.get().size() == 1,
+        "More than one output returned from executing LLM.");
+    ET_CHECK_MSG(
+        outputs_res.get()[0].isTensor(),
+        "Non Tensor Output returned from executing LLM");
+
+    // Bump start_pos by 1
+    start_pos.mutable_data_ptr<int64_t>()[0]++;
+
+    // Return the logits tensor
+    return outputs_res.get()[0].toTensor();
+  } else { // no kv cache
+    std::vector<EValue> inputs;
+    auto tokens = managed_tokens.get_aliasing_tensor();
+    auto start_pos = managed_start_pos.get_aliasing_tensor();
+
+    // When not using kv-cache our input is the entire history of tokens we have
+    // seen, so resize input to be 1 larger and append the new token to the end.
+    // TODO does this work in ATen mode?
+    tokens.mutable_data_ptr<int64_t>()[tokens.size(1) - 1] = input_token;
+
+    // inputs:[tokens]
+    inputs.push_back(tokens);
+
+    Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
+    ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
+    ET_CHECK_MSG(
+        outputs_res.get().size() == 1,
+        "More than one output returned from executing LLM.");
+    ET_CHECK_MSG(
+        outputs_res.get()[0].isTensor(),
+        "Non Tensor Output returned from executing LLM");
+
+    if (tokens.size(1) < max_seq_len) {
+      // Resize the tokens tensor to be 1 larger for next step.
+      managed_tokens.resize({1, static_cast<int>(tokens.size(1) + 1)});
+    }
+
+    // Return the logits tensor
+    return outputs_res.get()[0].toTensor();
+  }
+}
+
 Error Runner::generate(
     const std::string& prompt,
     int32_t seq_len,
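// ---------------------------------------------------------------------------
// A standalone sketch (plain C++, not ExecuTorch) of the kv-cache step protocol
// that run_model_step() above implements: each call feeds exactly one token
// plus the current start_pos, and start_pos is bumped by one after the forward
// call. fake_forward() only stands in for module_->forward(); all names and
// token values here are illustrative.
#include <cstdint>
#include <cstdio>
#include <vector>

// Pretend "model" step: returns a dummy logit derived from its inputs.
static float fake_forward(int64_t token, int64_t start_pos) {
  return static_cast<float>(token) * 0.5f + static_cast<float>(start_pos);
}

int main() {
  std::vector<int64_t> tokens(1);       // 1x1 token buffer, like token_data
  std::vector<int64_t> start_pos = {0}; // single-element position buffer
  const std::vector<int64_t> prompt = {1, 306, 4459};

  for (int64_t t : prompt) {
    tokens[0] = t; // overwrite the single slot, as the kv-cache branch does
    float logit = fake_forward(tokens[0], start_pos[0]);
    std::printf("fed token %lld at pos %lld -> logit %.1f\n",
                static_cast<long long>(t),
                static_cast<long long>(start_pos[0]), logit);
    start_pos[0]++; // bump the position after the step, like run_model_step
  }
  return 0;
}
// ---------------------------------------------------------------------------
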
@@ -189,9 +238,6 @@ Error Runner::generate(
       prompt_tokens,
       &num_prompt_tokens);
 
-  for (int i = 0; i < num_prompt_tokens; i++) {
-    ET_LOG(Info, "prompt_tokens[%d]: %d", i, prompt_tokens[i]);
-  }
   ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
   ET_CHECK_MSG(
       num_prompt_tokens < max_seq_len_,
@@ -202,89 +248,68 @@ Error Runner::generate(
       "Sequence length exceeded - please increase the seq_len value passed to generate()");
 
   // start the main loop
-  int next; // will store the next token in the sequence
-  int64_t pos = num_prompt_tokens - 1; // position in the sequence
-  int token = prompt_tokens[pos]; // prefill starts from 0 to num_prompt_tokens
-  int logits_index = 0; // index of the logits tensor in the output
-  std::vector<exec_aten::SizesType> input_shape = {1, 1};
-  std::vector<exec_aten::SizesType> pos_shape = {1};
+  int64_t pos = 0; // position in the sequence
+
   std::vector<int64_t> token_data; // allocate space for the tokens
-  std::vector<int64_t> pos_data; // allocate space for the tokens
+  std::vector<exec_aten::SizesType> token_shape = {1, seq_len};
+
+  std::vector<int64_t> start_pos_data; // allocate space for the tokens
+  std::vector<exec_aten::SizesType> start_pos_shape = {1};
 
   if (use_kv_cache_) {
-    // set pos to 0, refill token by token
-    pos = 0;
+    // hard code these to size 1 as kv cache is locked to static size right now.
     token_data.resize(1);
-    pos_data.resize(seq_len);
+    token_shape[1] = 1;
+    start_pos_data.resize(1);
+    start_pos_data[0] = 0;
   } else {
-    // reserve data for tokens, notice the size is still 0.
+    // reserve data for tokens, notice the size is still 0 but the capacity is
+    // seq_len.
     token_data.resize(seq_len);
   }
 
   // initialize tensor wrappers
-  ManagedTensor pos_managed(
-      pos_data.data(), pos_data.size(), pos_shape, ScalarType::Long);
-
-  // copy prompt tokens into data
-  for (int i = 0; i <= pos; ++i) {
-    // @lint-ignore CLANGTIDY facebook-hte-LocalUncheckedArrayBounds
-    token_data[i] = prompt_tokens[i];
-    if (i > 0) {
-      printf(
-          "%s",
-          ET_UNWRAP(
-              tokenizer_->decode(prompt_tokens[i - 1], prompt_tokens[i])));
-    }
-  }
-
-  // create a 1xN int tensor with next as value
-  while (pos + 1 < seq_len) {
+  ManagedTensor tokens_managed(
+      token_data.data(),
+      128, // TODO clean up unused 128 here as ManagedTensor ignores this arg in
+           // ctor
+      token_shape,
+      ScalarType::Long);
+  // Create with the max shape to appropriately set the capacity of this
+  // tensor, then resize back to 1 for first input.
+  tokens_managed.resize({1, 1});
+
+  ManagedTensor start_pos_managed(
+      start_pos_data.data(), 128, start_pos_shape, ScalarType::Long);
+
+  int64_t prev_token = -1;
+  int64_t cur_token = prompt_tokens[0];
+
+  while (pos < seq_len - 1) {
     // ET_LOG(Info, "Generating step %d...", pos);
-    // set the current token in the tensor
-    std::vector<EValue> inputs;
-    if (use_kv_cache_) {
-      token_data[0] = token;
-      input_shape[1] = 1;
-      // inputs: [tokens, start_pos, k_cache, v_cache]
-      inputs.emplace_back(pos_managed.get_aliasing_tensor());
-    } else {
-      // @lint-ignore CLANGTIDY facebook-hte-LocalUncheckedArrayBounds
-      token_data[pos] = token;
-      input_shape[1] = pos + 1;
-    }
-    ManagedTensor token_managed(
-        token_data.data(), token_data.size(), input_shape, ScalarType::Long);
-    inputs.insert(inputs.begin(), token_managed.get_aliasing_tensor());
-    // For kv cache, inputs: [tokens, start_pos, k_cache, v_cache]
-    // Otherwise inputs: [tokens]
-    Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
-    ET_CHECK_MSG(
-        outputs_res.ok(),
-        "Execution of method forward failed with status 0x%" PRIx32,
-        static_cast<int32_t>(outputs_res.error()));
-    // ET_LOG(Info, "Model executed successfully.");
 
-    std::vector<EValue> outputs = outputs_res.get();
-    // Check the outputs.
-    ET_CHECK_MSG(
-        outputs.size() == 1 && outputs.at(0).isTensor(),
-        "Expecting output to have exactly 1 tensor output. Got %zu outputs.",
-        outputs.size());
+    Result<torch::executor::Tensor> logits_res =
+        run_model_step(cur_token, tokens_managed, start_pos_managed, seq_len);
+
     if (pos == num_prompt_tokens) {
       timers_.first_token_ms = util::time_in_ms();
     } else if (pos == num_prompt_tokens - 1) {
       timers_.prompt_eval_end_ms = util::time_in_ms();
     }
-    int32_t next_tok;
-    exec_aten::Tensor logits_tensor = outputs.at(logits_index).toTensor();
+
+    ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
+    exec_aten::Tensor& logits_tensor = logits_res.get();
+
+    prev_token = cur_token;
+
     long sample_start_time_ms = util::time_in_ms();
     switch (logits_tensor.scalar_type()) {
       case ScalarType::Float: {
-        next_tok = logitsToToken<float>(logits_tensor, pos, 0);
+        cur_token = logitsToToken<float>(logits_tensor, pos, 0);
         break;
       }
       case ScalarType::Half: {
-        next_tok = logitsToToken<exec_aten::Half>(logits_tensor, pos, 0);
+        cur_token = logitsToToken<exec_aten::Half>(logits_tensor, pos, 0);
         break;
       }
       default:
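// ---------------------------------------------------------------------------
// A standalone sketch (plain C++, not ExecuTorch) of the capacity/resize
// pattern used for tokens_managed above: the backing buffer is sized once for
// the maximum sequence length, and only the logical shape grows by one per
// step, so nothing is reallocated inside the generation loop. All names and
// token values here are illustrative.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const size_t seq_len = 8;                 // maximum sequence length
  std::vector<int64_t> token_data(seq_len); // fixed backing storage
  const std::vector<int64_t> prompt = {1, 306, 4459, 263};

  size_t logical_len = 1; // like tokens_managed.resize({1, 1}) before step 0
  for (size_t pos = 0; pos < prompt.size(); ++pos) {
    token_data[logical_len - 1] = prompt[pos]; // write into the last logical slot
    std::printf("step %zu: logical shape = {1, %zu}\n", pos, logical_len);
    if (logical_len < seq_len) {
      ++logical_len; // like managed_tokens.resize({1, tokens.size(1) + 1})
    }
  }
  return 0;
}
// ---------------------------------------------------------------------------
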
@@ -299,19 +324,12 @@ Error Runner::generate(
     // advance the state machine
     if (pos < num_prompt_tokens - 1) {
       // prefill, force the next token to be the next prompt token
-      next = prompt_tokens[pos + 1];
-    } else {
-      // otherwise sample the next token from the logits
-      next = next_tok;
+      cur_token = prompt_tokens[pos + 1];
     }
-    // ET_LOG(Info, "Output saved, next = %d", next);
     pos++;
-    if (use_kv_cache_) {
-      pos_data.at(0) = pos;
-    }
 
     // print the token as string, decode it with the Tokenizer object
-    auto piece_res = tokenizer_->decode(token, next);
+    auto piece_res = tokenizer_->decode(prev_token, cur_token);
     ET_CHECK(piece_res.ok());
     const char* piece = piece_res.get();
 
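// ---------------------------------------------------------------------------
// A standalone sketch (plain C++, not ExecuTorch) of the sliding
// (prev_token, cur_token) pair handed to tokenizer_->decode() above. In a
// llama2.c-style decoder the previous token is what lets decode decide, for
// example, whether to strip the leading space right after BOS; fake_decode()
// below is only a stand-in for that idea, and all names and token values are
// illustrative.
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// Toy stand-in for tokenizer_->decode(prev, cur): pretend token 1 is BOS and
// suppress the leading space for the piece that immediately follows it.
static std::string fake_decode(int64_t prev, int64_t cur) {
  std::string piece = "tok" + std::to_string(cur);
  return (prev == 1) ? piece : " " + piece;
}

int main() {
  const std::vector<int64_t> stream = {1, 306, 4459, 263};

  for (size_t i = 1; i < stream.size(); ++i) {
    int64_t prev_token = stream[i - 1]; // the token fed on the previous step
    int64_t cur_token = stream[i];      // the token just sampled (or forced)
    std::printf("%s", fake_decode(prev_token, cur_token).c_str());
  }
  std::printf("\n");
  return 0;
}
// ---------------------------------------------------------------------------
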
@@ -328,22 +346,20 @@ Error Runner::generate(
     }
 
     // data-dependent terminating condition: we have n_eos_ number of EOS
-    if (pos >= num_prompt_tokens && next == eos_id_) {
+    if (pos >= num_prompt_tokens && cur_token == eos_id_) {
       printf("\n");
       ET_LOG(Info, "\nReached to the end of generation");
       break;
     }
-
-    token = next;
   }
   timers_.inference_end_ms = util::time_in_ms();
   printf("\n");
 
-  if (pos + 1 == seq_len) {
+  if (pos == seq_len) {
     ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
   }
 
-  timers_.printReport(num_prompt_tokens, (pos + 1) - num_prompt_tokens);
+  timers_.printReport(num_prompt_tokens, pos - num_prompt_tokens);
 
   delete[] prompt_tokens;
   return Error::Ok;