
Commit e3706d3

MTK Android Llama Runner
1 parent ce67b54 commit e3706d3

3 files changed, +434 -0 lines changed
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
#pragma once

namespace torch::executor {
using llm_helper::LLMType;

// Sizes
const size_t PROMPT_TOKEN_BATCH_SIZE = 128;
const size_t CACHE_SIZE = 512;
const size_t HIDDEN_SIZE = 4096;
const size_t NUM_HEAD = 32;
const size_t NUM_LAYER = 32;
const size_t MAX_TOKEN_LENGTH = 8192;
const double ROT_EMB_BASE = 500000;

// Types
const LLMType MODEL_INPUT_TYPE = LLMType::FP32;
const LLMType MODEL_OUTPUT_TYPE = LLMType::FP32;
const LLMType CACHE_TYPE = LLMType::FP32;
const LLMType MASK_TYPE = LLMType::FP32;
const LLMType ROT_EMB_TYPE = LLMType::FP32;

// Paths
const std::string TOKENIZER_PATH="/data/local/tmp/et-mtk/llama3/tokenizer.model";
const std::string TOKEN_EMBEDDING_PATH="/data/local/tmp/et-mtk/llama3/embedding_llama3-8B-instruct_fp32.bin";

// Comma-Separated Paths
const std::string PROMPT_MODEL_PATHS="/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_128t512c_0.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_128t512c_1.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_128t512c_2.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_128t512c_3.pte,";

// Comma-Separated Paths
const std::string GEN_MODEL_PATHS="/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_1t512c_0.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_1t512c_1.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_1t512c_2.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_1t512c_3.pte,";

} // namespace torch::executor
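
Note: the paths above describe a Llama-3-8B-Instruct model exported as four .pte chunks per mode; the prompt (pre-fill) chunks use a 128-token batch with a 512-token cache (the 128t512c suffix), while the generation chunks process one token at a time (1t512c). The trailing comma in each list leaves an empty entry that the split step is presumably expected to ignore. As a rough illustration of how such a list breaks into per-chunk paths (the runner itself uses its own utils::split helper; split_paths below is a hypothetical stand-in, not the library function):

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for utils::split: break a comma-separated list of
// model chunk paths into individual paths, skipping the empty entry left by
// the trailing comma.
std::vector<std::string> split_paths(const std::string& paths, char delim = ',') {
  std::vector<std::string> out;
  std::stringstream ss(paths);
  std::string item;
  while (std::getline(ss, item, delim)) {
    if (!item.empty()) {
      out.push_back(item);
    }
  }
  return out;
}

int main() {
  const std::string prompt_model_paths =
      "/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_128t512c_0.pte,"
      "/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_128t512c_1.pte,";
  for (const auto& path : split_paths(prompt_model_paths)) {
    std::cout << path << std::endl; // one path per model chunk
  }
  return 0;
}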
Lines changed: 333 additions & 0 deletions
@@ -0,0 +1,333 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * Copyright (c) 2024 MediaTek Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

/* Copyright Statement:
 *
 * This software/firmware and related documentation ("MediaTek Software") are
 * protected under relevant copyright laws. The information contained herein
 * is confidential and proprietary to MediaTek Inc. and/or its licensors.
 * Without the prior written permission of MediaTek inc. and/or its licensors,
 * any reproduction, modification, use or disclosure of MediaTek Software,
 * and information contained herein, in whole or in part, shall be strictly
 * prohibited.
 */
/* MediaTek Inc. (C) 2024. All rights reserved.
 *
 * BY OPENING THIS FILE, RECEIVER HEREBY UNEQUIVOCALLY ACKNOWLEDGES AND AGREES
 * THAT THE SOFTWARE/FIRMWARE AND ITS DOCUMENTATIONS ("MEDIATEK SOFTWARE")
 * RECEIVED FROM MEDIATEK AND/OR ITS REPRESENTATIVES ARE PROVIDED TO RECEIVER ON
 * AN "AS-IS" BASIS ONLY. MEDIATEK EXPRESSLY DISCLAIMS ANY AND ALL WARRANTIES,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE OR NONINFRINGEMENT.
 * NEITHER DOES MEDIATEK PROVIDE ANY WARRANTY WHATSOEVER WITH RESPECT TO THE
 * SOFTWARE OF ANY THIRD PARTY WHICH MAY BE USED BY, INCORPORATED IN, OR
 * SUPPLIED WITH THE MEDIATEK SOFTWARE, AND RECEIVER AGREES TO LOOK ONLY TO SUCH
 * THIRD PARTY FOR ANY WARRANTY CLAIM RELATING THERETO. RECEIVER EXPRESSLY
 * ACKNOWLEDGES THAT IT IS RECEIVER'S SOLE RESPONSIBILITY TO OBTAIN FROM ANY
 * THIRD PARTY ALL PROPER LICENSES CONTAINED IN MEDIATEK SOFTWARE. MEDIATEK
 * SHALL ALSO NOT BE RESPONSIBLE FOR ANY MEDIATEK SOFTWARE RELEASES MADE TO
 * RECEIVER'S SPECIFICATION OR TO CONFORM TO A PARTICULAR STANDARD OR OPEN
 * FORUM. RECEIVER'S SOLE AND EXCLUSIVE REMEDY AND MEDIATEK'S ENTIRE AND
 * CUMULATIVE LIABILITY WITH RESPECT TO THE MEDIATEK SOFTWARE RELEASED HEREUNDER
 * WILL BE, AT MEDIATEK'S OPTION, TO REVISE OR REPLACE THE MEDIATEK SOFTWARE AT
 * ISSUE, OR REFUND ANY SOFTWARE LICENSE FEES OR SERVICE CHARGE PAID BY RECEIVER
 * TO MEDIATEK FOR SUCH MEDIATEK SOFTWARE AT ISSUE.
 *
 * The following software/firmware and/or related documentation ("MediaTek
 * Software") have been modified by MediaTek Inc. All revisions are subject to
 * any receiver's applicable license agreements with MediaTek Inc.
 */

#include "executorch/backends/mediatek/runtime/include/NeuronBufferAllocator.h"
48+
#include <executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
49+
50+
#include <ctime>
51+
#include <iostream>
52+
#include <memory>
53+
#include <random>
54+
55+
#include <executorch/extension/data_loader/file_data_loader.h>
56+
#include <executorch/extension/evalue_util/print_evalue.h>
57+
#include <executorch/runtime/executor/method.h>
58+
#include <executorch/runtime/executor/program.h>
59+
#include <executorch/runtime/platform/log.h>
60+
#include <executorch/runtime/platform/profiler.h>
61+
#include <executorch/runtime/platform/runtime.h>
62+
// #include <executorch/util/util.h>
63+
#include <executorch/extension/llm/runner/util.h>
64+
#include <executorch/runtime/core/result.h>
65+
66+
#include "llama_runner/ModelChunk.h"
67+
#include "llama_runner/Utils.h"
68+
#include "llama_runner/llm_helper/include/llm_types.h"
69+
#include "llama_runner/llm_helper/include/llama_runner_values.h"
70+
71+
static uint64_t MAX_RESPONSE = 50; // Maximum number of tokens to generate.
72+
// Global BOS and EOS option for tokenization (encoding)
73+
static constexpr int8_t kAddBos = 1;
74+
static constexpr int8_t kAddEos = 0;
75+
76+
using namespace torch::executor;
77+
using namespace torch::executor::llm_helper;
78+
using torch::executor::utils::Timer;
79+
80+
MTKLlamaRunner::MTKLlamaRunner(
    const std::string& model_path,
    const std::string& tokenizer_path,
    const float temperature)
    : modeloptions_(get_model_options()),
      modelpaths_(get_model_paths()) {
  runtime_init();
  ET_LOG(
      Info,
      "Creating MTK Llama runner. Currently it will self-load .pte, .bin, and .so files. Initiated runtime_init().");
}

Error MTKLlamaRunner::load() {
  if (is_loaded()) {
    return Error::Ok;
  }

  // Load tokenizer
  ET_LOG(Info, "Loading tokenizer.");
  tokenizer_ = load_tokenizer();
  ET_LOG(Info, "Complete loading tokenizer.");

  // Load prompt model
  runtime_ = std::make_unique<LlamaRuntime>();
  ET_LOG(Info, "Loading prompt model.");
  runtime_->Initialize(modeloptions_, modelpaths_);
  ET_LOG(Info, "Complete loading prompt model.");

  return Error::Ok;
}

bool MTKLlamaRunner::is_loaded() const {
  return tokenizer_ && runtime_;
}

Error MTKLlamaRunner::generate(
    const std::string& prompt,
    int32_t seq_len,
    std::function<void(const std::string&)> token_callback,
    std::function<void(const Stats&)> stats_callback) {

  if (!is_loaded()) {
    ET_CHECK_OK_OR_RETURN_ERROR(load());
  }

  // Wrap the token_callback with print function
  std::function<void(const std::string&)> wrapped_callback =
      [token_callback](const std::string& piece) {
        util::safe_printf(piece.c_str());
        fflush(stdout);
        if (token_callback) {
          token_callback(piece);
        }
      };

  ET_LOG(Info, "Starting inference from MTKLlamaRunner");
  inference(*runtime_.get(), tokenizer_, prompt, wrapped_callback);
  ET_LOG(Info, "Completed inference from MTKLlamaRunner");

  return Error::Ok;
}

void MTKLlamaRunner::stop() {
  if (is_loaded()) {
    runtime_->Release();
  } else {
    ET_LOG(Error, "Llama Runtime is not loaded, cannot stop");
  }
}

LlamaModelOptions MTKLlamaRunner::get_model_options() {
  LlamaModelOptions options = {
      // Sizes
      .prompt_token_batch_size = PROMPT_TOKEN_BATCH_SIZE,
      .cache_size = CACHE_SIZE,
      .hidden_size = HIDDEN_SIZE,
      .num_head = NUM_HEAD,
      .num_layer = NUM_LAYER,
      .max_token_length = MAX_TOKEN_LENGTH,
      .rot_emb_base = ROT_EMB_BASE,

      // Types
      .model_input_type = MODEL_INPUT_TYPE,
      .model_output_type = MODEL_OUTPUT_TYPE,
      .cache_type = CACHE_TYPE,
      .mask_type = MASK_TYPE,
      .rot_emb_type = ROT_EMB_TYPE};
  ET_LOG(Info, "Completed get_model_options");
  return options;
}

LlamaModelPaths MTKLlamaRunner::get_model_paths() {
  LlamaModelPaths model_paths = {
      .tokenizer_path = TOKENIZER_PATH,
      .token_embedding_path = TOKEN_EMBEDDING_PATH,
      .prompt_model_paths = utils::split(PROMPT_MODEL_PATHS, ','),
      .gen_model_paths = utils::split(GEN_MODEL_PATHS, ',')};
  ET_LOG(Info, "Completed get_model_paths");
  return model_paths;
}

Result<uint64_t> MTKLlamaRunner::digest_prompt(
    LlamaRuntime& llama_runtime,
    const std::unique_ptr<Tokenizer>& tokenizer,
    const std::vector<uint64_t> input_tokens) {
  const auto input_token_count = input_tokens.size();
  const auto prompt_token_batch_size = llama_runtime.GetTokenBatchSize();
  size_t cur_token_index = 0;

  Timer timer_digest_prompt([=](const auto elapsed_sec) {
    // Ideal prompt size is a multiple of prompt batch size
    const size_t ideal_prompt_size =
        std::ceil(float(input_token_count) / prompt_token_batch_size) *
        prompt_token_batch_size;
    ET_LOG(
        Info,
        "Done analyzing prompt in %f sec (%f tok/s)",
        elapsed_sec,
        (float)ideal_prompt_size / elapsed_sec);
  });

  auto getNextTokens = [&]() {
    const size_t num_tok_remain = input_token_count - cur_token_index;
    const size_t remainder = num_tok_remain % prompt_token_batch_size;
    const size_t num_new_tokens =
        remainder ? remainder : prompt_token_batch_size;
    const auto start = cur_token_index;
    const auto end = start + num_new_tokens;
    return std::vector(
        input_tokens.begin() + start, input_tokens.begin() + end);
  };

  void* logits;
  timer_digest_prompt.Start();
  while (cur_token_index < input_token_count) {
    const auto next_tokens = getNextTokens();
    ET_LOG(
        Debug,
        "Digest next tokens (size=%zu), 1st tok=%lu",
        next_tokens.size(),
        next_tokens[0]);
    logits = llama_runtime.Run(next_tokens);
    cur_token_index += next_tokens.size();
  }
  timer_digest_prompt.End();

  const auto vocab_size = tokenizer->vocab_size();
  const auto logits_type = llama_runtime.GetModelOptions().model_output_type;
  const auto first_output_token =
      utils::argmax(logits_type, logits, vocab_size);
  return first_output_token;
}

Error MTKLlamaRunner::gen_response(
    LlamaRuntime& llama_runtime,
    const std::unique_ptr<Tokenizer>& tokenizer,
    const uint64_t input_token,
    std::function<void(const std::string&)> token_callback) {
  Timer timer_model_swap(
      [](const auto elapsed_sec) { ET_LOG(Info, "Model swapped."); });

  // Swap to gen mode
  timer_model_swap.Start();
  llama_runtime.SwapModel(1);
  timer_model_swap.End();

  size_t gen_tok_count = 0;
  uint64_t prev_token = input_token;
  uint64_t output_token = input_token;

  auto decode_res = tokenizer->decode(prev_token, output_token);
  ET_CHECK_OR_RETURN_ERROR(
      decode_res.ok(),
      InvalidState,
      "Tokenizer failed to decode first generated token: %lu",
      output_token);
  std::string full_response = std::move(decode_res.get());
  std::vector<uint64_t> full_response_tokens = {input_token};

  const auto vocab_size = tokenizer->vocab_size();
  const auto logits_type = llama_runtime.GetModelOptions().model_output_type;

  double gen_total_time_sec = 0;
  Timer timer_gen_token(
      [&](const auto elapsed_sec) { gen_total_time_sec += elapsed_sec; });

  // Print first output token
  token_callback(full_response);

  while (gen_tok_count++ < MAX_RESPONSE &&
         llama_runtime.GetTokenIndex() < modeloptions_.max_token_length) {
    timer_gen_token.Start();
    void* logits = llama_runtime.Run({output_token});
    timer_gen_token.End();

    prev_token = output_token;
    output_token = utils::argmax(logits_type, logits, vocab_size);
    full_response_tokens.push_back(output_token);

    // Stop when output is EOS
    if (output_token == tokenizer->eos_tok()) {
      token_callback("</eos>");
      break;
    }
    auto decode_res = tokenizer->decode(prev_token, output_token);
    ET_CHECK_OR_RETURN_ERROR(
        decode_res.ok(),
        InvalidState,
        "Tokenizer failed to decode generated token %lu",
        output_token);
    const std::string tok_str = std::move(decode_res.get());
    full_response += tok_str;
    token_callback(tok_str);
  }

  std::cout << "\n\n[Generated Tokens]\n"
            << utils::to_string(full_response_tokens) << std::endl;

  ET_LOG(
      Info,
      "Token generation speed: %f tok/s",
      gen_tok_count / gen_total_time_sec);

  return Error::Ok;
}

Error MTKLlamaRunner::inference(
    LlamaRuntime& llama_runtime,
    const std::unique_ptr<Tokenizer>& tokenizer,
    const std::string& prompt,
    std::function<void(const std::string&)> token_callback) {
  // Tokenize input prompt
  auto encode_res = tokenizer->encode(prompt, kAddBos, kAddEos);
  ET_CHECK_OR_RETURN_ERROR(
      encode_res.ok(), InvalidState, "Tokenizer failed to encode prompt");
  const auto input_tokens = std::move(encode_res.get());

  // Run prompt mode (pre-fill)
  auto prefill_res = digest_prompt(llama_runtime, tokenizer, input_tokens);
  ET_CHECK_OR_RETURN_ERROR(
      prefill_res.ok(), InvalidState, "Failed to digest prompt");
  const auto first_output_token = prefill_res.get();

  // Run generation mode (decoding)
  return gen_response(
      llama_runtime, tokenizer, first_output_token, token_callback);
}

std::unique_ptr<Tokenizer> MTKLlamaRunner::load_tokenizer() {
  std::unique_ptr<Tokenizer> tokenizer;
  // Assumes that tokenizer type is Tiktoken
  tokenizer = torch::executor::get_tiktoken_for_llama();
  tokenizer->load(modelpaths_.tokenizer_path);
  return tokenizer;
}
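
For context, a minimal usage sketch of this runner follows. It is hypothetical: the include path mirrors the one used in this commit, but the prompt text, seq_len value, and build setup are assumptions, and per the constructor above the model_path/tokenizer_path arguments are currently ignored in favor of the hard-coded paths in the constants header.

#include <iostream>
#include <string>

#include <executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>

using namespace torch::executor;

int main() {
  // The constructor currently ignores these arguments and self-loads the
  // hard-coded .pte/.bin/tokenizer paths defined in the constants header.
  MTKLlamaRunner runner(
      /*model_path=*/"", /*tokenizer_path=*/"", /*temperature=*/0.0f);

  // generate() loads lazily, but an explicit load surfaces errors early.
  if (runner.load() != Error::Ok) {
    std::cerr << "Failed to load MTK Llama runner" << std::endl;
    return 1;
  }

  // Stream generated tokens through the callback as they are decoded.
  runner.generate(
      "What is the capital of France?",
      /*seq_len=*/128,
      [](const std::string& piece) { std::cout << piece << std::flush; },
      /*stats_callback=*/{});
  return 0;
}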

0 commit comments
