@@ -73,17 +73,22 @@ static uint64_t MAX_RESPONSE = 50; // Maximum number of tokens to generate.
 static constexpr int8_t kAddBos = 1;
 static constexpr int8_t kAddEos = 0;
 
-using namespace torch::executor;
-using namespace torch::executor::llm_helper;
-using torch::executor::utils::Timer;
+using namespace example::llm_helper;
+using example::utils::argmax;
+using example::utils::split;
+using example::utils::Timer;
+using example::utils::to_string;
+using namespace mtk::vars;
+
+namespace llm = ::executorch::extension::llm;
 
 MTKLlamaRunner::MTKLlamaRunner(
     const std::string& model_path,
     const std::string& tokenizer_path,
     const float temperature)
     : modeloptions_(get_model_options()),
       modelpaths_(get_model_paths()) {
-  runtime_init();
+  executorch::runtime::runtime_init();
   ET_LOG(
       Info,
       "Creating MTK Llama runner. Current it will self-load .pte, .bin, and .so files. Initiated runtime_init().");
@@ -125,7 +130,7 @@ Error MTKLlamaRunner::generate(
   // Wrap the token_callback with print function
   std::function<void(const std::string&)> wrapped_callback =
       [token_callback](const std::string& piece) {
-        util::safe_printf(piece.c_str());
+        llm::safe_printf(piece.c_str());
         fflush(stdout);
         if (token_callback) {
           token_callback(piece);
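
The only change in this hunk is where safe_printf comes from (the new llm alias instead of the removed util namespace); the surrounding context shows the wrapping pattern itself: a lambda that echoes each decoded piece, flushes stdout, then forwards to the user's callback. A standard-library-only sketch of that pattern, with illustrative names rather than the runner's API:

#include <cstdio>
#include <functional>
#include <string>

std::function<void(const std::string&)> wrap_with_print(
    std::function<void(const std::string&)> user_callback) {
  return [user_callback](const std::string& piece) {
    std::fputs(piece.c_str(), stdout);  // echo the decoded piece
    std::fflush(stdout);                // keep output streaming promptly
    if (user_callback) {
      user_callback(piece);             // then forward to the caller
    }
  };
}

int main() {
  auto cb = wrap_with_print([](const std::string& piece) {
    (void)piece;  // e.g. append to a transcript buffer
  });
  cb("token ");
  cb("stream\n");
  return 0;
}
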
@@ -172,8 +177,8 @@ LlamaModelPaths MTKLlamaRunner::get_model_paths() {
   LlamaModelPaths model_paths = {
       .tokenizer_path = TOKENIZER_PATH,
       .token_embedding_path = TOKEN_EMBEDDING_PATH,
-      .prompt_model_paths = utils::split(PROMPT_MODEL_PATHS, ','),
-      .gen_model_paths = utils::split(GEN_MODEL_PATHS, ',')};
+      .prompt_model_paths = split(PROMPT_MODEL_PATHS, ','),
+      .gen_model_paths = split(GEN_MODEL_PATHS, ',')};
   ET_LOG(Info, "Completed get_model_paths");
   return model_paths;
 }
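
PROMPT_MODEL_PATHS and GEN_MODEL_PATHS are comma-separated path lists that split(...) turns into vectors; the diff only drops the utils:: qualifier now that the function is pulled in with a using-declaration. A rough approximation of such a splitter (the real example::utils::split may differ, e.g. in how it treats empty tokens):

#include <sstream>
#include <string>
#include <vector>

std::vector<std::string> split_csv(const std::string& input, char delim) {
  std::vector<std::string> parts;
  std::string token;
  std::istringstream stream(input);
  while (std::getline(stream, token, delim)) {
    if (!token.empty()) {
      parts.push_back(token);
    }
  }
  return parts;
}

// Usage: split_csv("chunk0.pte,chunk1.pte", ',') -> {"chunk0.pte", "chunk1.pte"}
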
@@ -225,8 +230,7 @@ Result<uint64_t> MTKLlamaRunner::digest_prompt(
 
   const auto vocab_size = tokenizer->vocab_size();
   const auto logits_type = llama_runtime.GetModelOptions().model_output_type;
-  const auto first_output_token =
-      utils::argmax(logits_type, logits, vocab_size);
+  const auto first_output_token = argmax(logits_type, logits, vocab_size);
   return first_output_token;
 }
 
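
Here and in the generation loop below, argmax(logits_type, logits, vocab_size) performs greedy decoding: it scans the logits buffer and returns the id of the highest-scoring token, dispatching on the model's output dtype. A simplified, float-only version of that idea (the real helper also handles other output types):

#include <cstddef>
#include <cstdint>

uint64_t argmax_f32(const float* logits, size_t vocab_size) {
  uint64_t best = 0;
  for (size_t i = 1; i < vocab_size; ++i) {
    if (logits[i] > logits[best]) {
      best = i;
    }
  }
  return best;  // greedy pick: token id with the highest logit
}
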
@@ -273,7 +277,7 @@ Error MTKLlamaRunner::gen_response(
     timer_gen_token.End();
 
     prev_token = output_token;
-    output_token = utils::argmax(logits_type, logits, vocab_size);
+    output_token = argmax(logits_type, logits, vocab_size);
     full_response_tokens.push_back(output_token);
 
     // Stop when output is EOS
@@ -293,7 +297,7 @@ Error MTKLlamaRunner::gen_response(
   }
 
   std::cout << "\n\n[Generated Tokens]\n"
-            << utils::to_string(full_response_tokens) << std::endl;
+            << to_string(full_response_tokens) << std::endl;
 
   ET_LOG(
       Info,
@@ -327,7 +331,7 @@ Error MTKLlamaRunner::inference(
 std::unique_ptr<Tokenizer> MTKLlamaRunner::load_tokenizer() {
   std::unique_ptr<Tokenizer> tokenizer;
   // Assumes that tokenizer type is Tiktoken
-  tokenizer = torch::executor::get_tiktoken_for_llama();
+  tokenizer = example::get_tiktoken_for_llama();
   tokenizer->load(modelpaths_.tokenizer_path);
   return tokenizer;
 }