#include <memory>
#include <sstream>

- namespace torch {
- namespace executor {
+ using executorch::aten::ScalarType;
+ using executorch::aten::SizesType;
+ using executorch::aten::Tensor;
+ using executorch::extension::from_blob;
+ using executorch::extension::Module;
+ using executorch::extension::TensorPtr;
+ using executorch::extension::llm::BPETokenizer;
+ using executorch::extension::llm::Sampler;
+ using executorch::extension::llm::time_in_ms;
+ using executorch::runtime::Error;
+ using executorch::runtime::EValue;
+ using executorch::runtime::MethodMeta;
+ using executorch::runtime::Result;
+ using executorch::runtime::TensorInfo;
+
+ // TODO: Remove this usage of an internal-only function.
+ using executorch::runtime::internal::set_tensor_data;
+
+ namespace example {

namespace {
- using namespace executorch::extension;
static constexpr auto kTopp = 0.9f;
void printReport(const Runner::Stats& stats);
std::string statsToJsonString(const Runner::Stats& stats);
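
Reviewer note: the hunk above is the standard ExecuTorch namespace migration. The old flat `torch::executor` wrapper namespaces are dropped, the runner's own code moves into `example`, and runtime/extension types are pulled in with explicit using-declarations. A minimal sketch of the pattern (the helper function is hypothetical, added only to keep the sketch self-contained):

```cpp
#include <executorch/runtime/core/error.h>

// New style: import the fine-grained executorch namespaces once...
using executorch::runtime::Error;

namespace example {

// ...then refer to types unqualified, instead of wrapping the whole file
// in `namespace torch { namespace executor { ... } }` as before.
Error do_nothing() {  // hypothetical helper, not part of this diff
  return Error::Ok;
}

} // namespace example
```
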
@@ -57,7 +73,7 @@ Error Runner::load() {
if (is_loaded()) {
return Error::Ok;
}
- stats_.model_load_start_ms = util::time_in_ms();
+ stats_.model_load_start_ms = time_in_ms();
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));

// Read out metadata from the model
@@ -97,7 +113,7 @@ Error Runner::load() {
temperature_,
kTopp,
static_cast<unsigned long long>(std::time(nullptr)));
- stats_.model_load_end_ms = util::time_in_ms();
+ stats_.model_load_end_ms = time_in_ms();

return Error::Ok;
}
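
Reviewer note: `load()` brackets its work with `time_in_ms()` calls so the report can derive a duration. A hedged, self-contained sketch of how such timestamp pairs are consumed (the struct mirrors only the two Stats fields used above; the report helper is invented for illustration):

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical mirror of the Stats fields touched in this hunk.
struct Stats {
  int64_t model_load_start_ms = 0;
  int64_t model_load_end_ms = 0;
};

void report_load_time(const Stats& stats) {
  // Duration is simply the difference of the two wall-clock timestamps.
  std::printf(
      "model load took %lld ms\n",
      static_cast<long long>(
          stats.model_load_end_ms - stats.model_load_start_ms));
}
```
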
@@ -125,7 +141,7 @@ T Runner::getMetadataHelper(std::string method_name, T default_val) {
}

template <typename T>
- int32_t Runner::logitsToToken(const exec_aten::Tensor& logits_tensor) {
+ int32_t Runner::logitsToToken(const Tensor& logits_tensor) {
T* logits = logits_tensor.mutable_data_ptr<T>();

// Since the logits are for all tokens, get the last token probabilities
@@ -135,7 +151,7 @@ int32_t Runner::logitsToToken(const exec_aten::Tensor& logits_tensor) {

// Given an input token, set up the inputs for the model and execute a
// single step, returning the logits tensor.
- Result<exec_aten::Tensor> Runner::run_model_step(
+ Result<Tensor> Runner::run_model_step(
int64_t input_token,
TensorPtr& token,
TensorPtr& start_pos,
@@ -167,7 +183,7 @@ Result<exec_aten::Tensor> Runner::run_model_step(
char* new_inp_addr = io_mem_mgr_.update_k_caches_read(j, el_size);
// inputs
ET_CHECK_MSG(
- internal::set_tensor_data(
+ set_tensor_data(
*kv_tensors[j], new_inp_addr, kv_tensors[j]->nbytes()) == Error::Ok,
"Failed to set input tensor when updating k_cache");
}
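
Reviewer note: the call being renamed here, `set_tensor_data`, rebinds an existing tensor to a new backing buffer without copying, which is how this runner slides its KV-cache windows forward each step. A self-contained sketch of the underlying idea, using a mock view type rather than the real (internal) ExecuTorch API:

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Minimal stand-in for a tensor view over externally managed memory.
// (Mock type for illustration; the runner uses real executorch Tensors.)
struct TensorView {
  float* data = nullptr;
  size_t numel = 0;
};

int main() {
  // A ring buffer holding two "cache" windows of 4 floats each.
  std::vector<float> ring(8, 0.0f);
  TensorView cache{ring.data(), 4};

  // Rebinding is just a pointer swap: no copy, no reallocation.
  cache.data = ring.data() + 4;  // slide the window forward

  std::printf("window now starts at offset %td\n", cache.data - ring.data());
}
```
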
@@ -177,13 +193,13 @@ Result<exec_aten::Tensor> Runner::run_model_step(
char* new_inp_addr = io_mem_mgr_.update_v_caches_read(v_idx, v_offset);

ET_CHECK_MSG(
- internal::set_tensor_data(
+ set_tensor_data(
*kv_tensors[j], new_inp_addr, kv_tensors[j]->nbytes()) == Error::Ok,
"Failed to set input tensor when updating v_cache");
// outputs
char* new_out_addr = io_mem_mgr_.update_v_caches_write(v_idx, v_offset);
ET_CHECK_MSG(
- internal::set_tensor_data(
+ set_tensor_data(
*kv_outputs[j], new_out_addr, kv_outputs[j]->nbytes()) == Error::Ok,
"Failed to set output tensor when updating v_cache");
ET_CHECK_MSG(
@@ -210,7 +226,7 @@ Error Runner::generate(

// First token time only measures the time it takes to encode the prompt and
// return a response token.
- stats_.inference_start_ms = util::time_in_ms();
+ stats_.inference_start_ms = time_in_ms();
shouldStop_ = false;

// Set the sequence length to the max seq length if not provided
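
Reviewer note: the next hunk initializes the attention-mask buffer: every position gets a large negative bias (-255) except the last slot, which is set to 0 so the current token can attend to itself. A self-contained sketch of that masking arithmetic, with an invented sequence length:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
  constexpr int kMaxSeqLen = 8;  // illustrative value, not from the diff
  float mask[kMaxSeqLen];

  // A large negative bias disables attention to unused slots; 0 enables it.
  std::fill(mask, mask + kMaxSeqLen, -255.0f);
  mask[kMaxSeqLen - 1] = 0.0f;

  // Added to attention scores before softmax, -255 drives a slot's weight
  // to effectively zero, while the 0-bias slot is kept.
  for (float m : mask) std::printf("%g ", m);
  std::printf("\n");
}
```
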
@@ -235,21 +251,21 @@ Error Runner::generate(
"Sequence length exceeded - please increase the seq_len value passed to generate()");

int32_t pos = 0, prev_token, cur_token = prompt_tokens[0];
- std::vector<exec_aten::SizesType> token_shape = {1, 1};
+ std::vector<SizesType> token_shape = {1, 1};

io_mem_mgr_.get_input_token_ptr()[0] = 0;
- std::vector<exec_aten::SizesType> start_pos_shape = {1, 1};
+ std::vector<SizesType> start_pos_shape = {1, 1};

float* atten_mask_ptr =
reinterpret_cast<float*>(io_mem_mgr_.get_atten_mask_ptr());
std::fill(atten_mask_ptr, atten_mask_ptr + max_seq_len_, -255);
atten_mask_ptr[max_seq_len_ - 1] = 0;

- std::vector<exec_aten::SizesType> atten_mask_shape = {1, max_seq_len_};
+ std::vector<SizesType> atten_mask_shape = {1, max_seq_len_};

- std::vector<exec_aten::SizesType> logits_data_shape = {1, vocab_size_};
+ std::vector<SizesType> logits_data_shape = {1, vocab_size_};

- std::vector<exec_aten::SizesType> hidden_states_data_shape = {1, 1, dim_};
+ std::vector<SizesType> hidden_states_data_shape = {1, 1, dim_};

// initialize tensor wrappers
auto token = from_blob(
@@ -274,7 +290,7 @@ Error Runner::generate(
method_meta->input_tensor_meta(input_index);

auto tensor_shape = tensor_meta->sizes();
- std::vector<exec_aten::SizesType> sizes(
+ std::vector<SizesType> sizes(
tensor_shape.data(), tensor_shape.data() + tensor_shape.size());
kv_tensors.emplace_back(from_blob(
io_mem_mgr_.get_k_caches_read_ptr(i),
@@ -284,7 +300,7 @@ Error Runner::generate(
// outputs
Result<TensorInfo> out_tensor_meta = method_meta->output_tensor_meta(i + 1);
tensor_shape = out_tensor_meta->sizes();
- sizes = std::vector<exec_aten::SizesType>{
+ sizes = std::vector<SizesType>{
tensor_shape.data(), tensor_shape.data() + tensor_shape.size()};
kv_outputs.emplace_back(from_blob(
io_mem_mgr_.get_k_caches_write_ptr(i),
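
Reviewer note: the `from_blob` calls in these hunks wrap pre-allocated I/O buffers in tensors whose shapes come from the method's metadata, so no data is copied. A hedged sketch of that wrapping pattern, assuming the `from_blob(data, sizes, dtype)` overload used above and the current tensor-extension header path (buffer and shape values are invented):

```cpp
#include <vector>

#include <executorch/extension/tensor/tensor.h>

using executorch::aten::ScalarType;
using executorch::aten::SizesType;
using executorch::extension::from_blob;

void wrap_example() {
  static float buffer[4] = {0.f, 1.f, 2.f, 3.f};  // pre-allocated storage
  std::vector<SizesType> sizes = {1, 4};          // shape from metadata

  // The resulting TensorPtr views `buffer`; it does not own or copy it.
  auto tensor = from_blob(buffer, sizes, ScalarType::Float);
  (void)tensor;
}
```
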
@@ -303,7 +319,7 @@ Error Runner::generate(
Result<TensorInfo> tensor_meta =
method_meta->input_tensor_meta(input_index);
auto tensor_shape = tensor_meta->sizes();
- std::vector<exec_aten::SizesType> sizes(
+ std::vector<SizesType> sizes(
tensor_shape.data(), tensor_shape.data() + tensor_shape.size());

kv_tensors.emplace_back(from_blob(
@@ -315,7 +331,7 @@ Error Runner::generate(
Result<TensorInfo> out_tensor_meta =
method_meta->output_tensor_meta(output_index);
tensor_shape = out_tensor_meta->sizes();
- sizes = std::vector<exec_aten::SizesType>{
+ sizes = std::vector<SizesType>{
tensor_shape.data(), tensor_shape.data() + tensor_shape.size()};

kv_outputs.push_back(from_blob(
@@ -342,19 +358,18 @@ Error Runner::generate(
auto logits_res = run_model_step(
cur_token, token, start_pos, atten_mask, kv_tensors, kv_outputs);
if (pos == num_prompt_tokens) {
- stats_.first_token_ms = util::time_in_ms();
+ stats_.first_token_ms = time_in_ms();
} else if (pos == num_prompt_tokens - 1) {
- stats_.prompt_eval_end_ms = util::time_in_ms();
+ stats_.prompt_eval_end_ms = time_in_ms();
}

ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
- exec_aten::Tensor& logits_tensor = logits_res.get();
+ Tensor& logits_tensor = logits_res.get();
prev_token = cur_token;
- long sample_start_time_ms = util::time_in_ms();
+ long sample_start_time_ms = time_in_ms();

cur_token = logitsToToken<float>(logits_tensor);
- stats_.aggregate_sampling_time_ms +=
- util::time_in_ms() - sample_start_time_ms;
+ stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms;

// advance the state machine
if (pos < num_prompt_tokens - 1) {
@@ -381,7 +396,7 @@ Error Runner::generate(
break;
}
}
- stats_.inference_end_ms = util::time_in_ms();
+ stats_.inference_end_ms = time_in_ms();

if (pos == seq_len) {
ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
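
Reviewer note: the timestamps captured in this loop (`inference_start_ms`, `prompt_eval_end_ms`, `first_token_ms`, `inference_end_ms`) are enough to derive the usual LLM latency metrics. A self-contained example of the arithmetic, with made-up millisecond values:

```cpp
#include <cstdio>

int main() {
  // Invented sample values, mirroring the Stats fields set in this loop.
  long inference_start_ms = 1000;
  long prompt_eval_end_ms = 1400;  // prompt fully encoded
  long first_token_ms = 1450;      // first generated token returned
  long inference_end_ms = 4450;    // generation finished
  int generated_tokens = 60;

  // Time-to-first-token covers prompt encoding plus one decode step.
  double ttft_s = (first_token_ms - inference_start_ms) / 1000.0;
  // Decode throughput excludes the prompt-encoding phase.
  double decode_s = (inference_end_ms - prompt_eval_end_ms) / 1000.0;

  std::printf("time to first token: %.2f s\n", ttft_s);
  std::printf("decode rate: %.1f tok/s\n", generated_tokens / decode_s);
}
```
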
@@ -650,5 +665,4 @@ template bool Runner::getMetadataHelper<bool>(
std::string method_name,
bool default_val);

- } // namespace executor
- } // namespace torch
+ } // namespace example