|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * Copyright (c) 2024 MediaTek Inc. |
| 4 | + * All rights reserved. |
| 5 | + * |
| 6 | + * This source code is licensed under the BSD-style license found in the |
| 7 | + * LICENSE file in the root directory of this source tree. |
| 8 | + */ |
| 9 | + |
| 10 | +/* Copyright Statement: |
| 11 | + * |
| 12 | + * This software/firmware and related documentation ("MediaTek Software") are |
| 13 | + * protected under relevant copyright laws. The information contained herein |
| 14 | + * is confidential and proprietary to MediaTek Inc. and/or its licensors. |
| 15 | + * Without the prior written permission of MediaTek inc. and/or its licensors, |
| 16 | + * any reproduction, modification, use or disclosure of MediaTek Software, |
| 17 | + * and information contained herein, in whole or in part, shall be strictly |
| 18 | + * prohibited. |
| 19 | + */ |
| 20 | +/* MediaTek Inc. (C) 2024. All rights reserved. |
| 21 | + * |
| 22 | + * BY OPENING THIS FILE, RECEIVER HEREBY UNEQUIVOCALLY ACKNOWLEDGES AND AGREES |
| 23 | + * THAT THE SOFTWARE/FIRMWARE AND ITS DOCUMENTATIONS ("MEDIATEK SOFTWARE") |
| 24 | + * RECEIVED FROM MEDIATEK AND/OR ITS REPRESENTATIVES ARE PROVIDED TO RECEIVER ON |
| 25 | + * AN "AS-IS" BASIS ONLY. MEDIATEK EXPRESSLY DISCLAIMS ANY AND ALL WARRANTIES, |
| 26 | + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF |
| 27 | + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE OR NONINFRINGEMENT. |
| 28 | + * NEITHER DOES MEDIATEK PROVIDE ANY WARRANTY WHATSOEVER WITH RESPECT TO THE |
| 29 | + * SOFTWARE OF ANY THIRD PARTY WHICH MAY BE USED BY, INCORPORATED IN, OR |
| 30 | + * SUPPLIED WITH THE MEDIATEK SOFTWARE, AND RECEIVER AGREES TO LOOK ONLY TO SUCH |
| 31 | + * THIRD PARTY FOR ANY WARRANTY CLAIM RELATING THERETO. RECEIVER EXPRESSLY |
| 32 | + * ACKNOWLEDGES THAT IT IS RECEIVER'S SOLE RESPONSIBILITY TO OBTAIN FROM ANY |
| 33 | + * THIRD PARTY ALL PROPER LICENSES CONTAINED IN MEDIATEK SOFTWARE. MEDIATEK |
| 34 | + * SHALL ALSO NOT BE RESPONSIBLE FOR ANY MEDIATEK SOFTWARE RELEASES MADE TO |
| 35 | + * RECEIVER'S SPECIFICATION OR TO CONFORM TO A PARTICULAR STANDARD OR OPEN |
| 36 | + * FORUM. RECEIVER'S SOLE AND EXCLUSIVE REMEDY AND MEDIATEK'S ENTIRE AND |
| 37 | + * CUMULATIVE LIABILITY WITH RESPECT TO THE MEDIATEK SOFTWARE RELEASED HEREUNDER |
| 38 | + * WILL BE, AT MEDIATEK'S OPTION, TO REVISE OR REPLACE THE MEDIATEK SOFTWARE AT |
| 39 | + * ISSUE, OR REFUND ANY SOFTWARE LICENSE FEES OR SERVICE CHARGE PAID BY RECEIVER |
| 40 | + * TO MEDIATEK FOR SUCH MEDIATEK SOFTWARE AT ISSUE. |
| 41 | + * |
| 42 | + * The following software/firmware and/or related documentation ("MediaTek |
| 43 | + * Software") have been modified by MediaTek Inc. All revisions are subject to |
| 44 | + * any receiver's applicable license agreements with MediaTek Inc. |
| 45 | + */ |
| 46 | + |
| 47 | +#include "executorch/backends/mediatek/runtime/include/NeuronBufferAllocator.h" |
| 48 | +#include <executorch/examples/mediatek/executor_runner/mtk_llama_runner.h> |
| 49 | + |
| 50 | +#include <ctime> |
| 51 | +#include <iostream> |
| 52 | +#include <memory> |
| 53 | +#include <random> |
| 54 | + |
| 55 | +#include <executorch/extension/data_loader/file_data_loader.h> |
| 56 | +#include <executorch/extension/evalue_util/print_evalue.h> |
| 57 | +#include <executorch/runtime/executor/method.h> |
| 58 | +#include <executorch/runtime/executor/program.h> |
| 59 | +#include <executorch/runtime/platform/log.h> |
| 60 | +#include <executorch/runtime/platform/profiler.h> |
| 61 | +#include <executorch/runtime/platform/runtime.h> |
| 62 | +// #include <executorch/util/util.h> |
| 63 | +#include <executorch/extension/llm/runner/util.h> |
| 64 | +#include <executorch/runtime/core/result.h> |
| 65 | + |
| 66 | +#include "llama_runner/ModelChunk.h" |
| 67 | +#include "llama_runner/Utils.h" |
| 68 | +#include "llama_runner/llm_helper/include/llm_types.h" |
| 69 | +#include "llama_runner/llm_helper/include/llama_runner_values.h" |
| 70 | + |
| 71 | +static constexpr uint64_t MAX_RESPONSE = 50; // Maximum number of tokens to generate.
| 72 | +// Global BOS and EOS option for tokenization (encoding) |
| 73 | +static constexpr int8_t kAddBos = 1; |
| 74 | +static constexpr int8_t kAddEos = 0; |
| 75 | + |
| 76 | +using namespace torch::executor; |
| 77 | +using namespace torch::executor::llm_helper; |
| 78 | +using torch::executor::utils::Timer; |
| 79 | + |
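|    | +// Note: the model_path, tokenizer_path and temperature arguments are currently
|    | +// unused; the model and tokenizer locations are taken from
|    | +// llama_runner_values.h via get_model_paths() (see the log message below).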
| 80 | +MTKLlamaRunner::MTKLlamaRunner( |
| 81 | + const std::string& model_path, |
| 82 | + const std::string& tokenizer_path, |
| 83 | + const float temperature) |
| 84 | + : modeloptions_(get_model_options()), |
| 85 | + modelpaths_(get_model_paths()) { |
| 86 | + runtime_init(); |
| 87 | + ET_LOG( |
| 88 | + Info, |
| 89 | + "Creating MTK Llama runner. Current it will self-load .pte, .bin, and .so files. Initiated runtime_init()."); |
| 90 | +} |
| 91 | + |
| 92 | +Error MTKLlamaRunner::load() { |
| 93 | + if (is_loaded()) { |
| 94 | + return Error::Ok; |
| 95 | + } |
| 96 | + |
| 97 | + // Load tokenizer |
| 98 | + ET_LOG(Info, "Loading tokenizer."); |
| 99 | + tokenizer_ = load_tokenizer(); |
| 100 | + ET_LOG(Info, "Complete loading tokenizer."); |
| 101 | + |
| 102 | + // Load prompt model |
| 103 | + runtime_ = std::make_unique<LlamaRuntime>(); |
| 104 | + ET_LOG(Info, "Loading prompt model."); |
| 105 | + runtime_->Initialize(modeloptions_, modelpaths_); |
| 106 | + ET_LOG(Info, "Complete loading prompt model."); |
| 107 | + |
| 108 | + return Error::Ok; |
| 109 | +} |
| 110 | + |
| 111 | +bool MTKLlamaRunner::is_loaded() const { |
| 112 | + return tokenizer_ && runtime_; |
| 113 | +} |
| 114 | + |
| 115 | +Error MTKLlamaRunner::generate( |
| 116 | + const std::string& prompt, |
| 117 | + int32_t seq_len, |
| 118 | + std::function<void(const std::string&)> token_callback, |
| 119 | + std::function<void(const Stats&)> stats_callback) { |
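|    | +  // Note: seq_len and stats_callback are currently not used by this runner.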
| 121 | + if (!is_loaded()) { |
| 122 | + ET_CHECK_OK_OR_RETURN_ERROR(load()); |
| 123 | + } |
| 124 | + |
| 125 | + // Wrap the token_callback with print function |
| 126 | + std::function<void(const std::string&)> wrapped_callback = |
| 127 | + [token_callback](const std::string& piece) { |
| 128 | + util::safe_printf(piece.c_str()); |
| 129 | + fflush(stdout); |
| 130 | + if (token_callback) { |
| 131 | + token_callback(piece); |
| 132 | + } |
| 133 | + }; |
| 134 | + |
| 135 | + ET_LOG(Info, "Starting inference from MTKLlamaRunner"); |
| 136 | + inference(*runtime_.get(), tokenizer_, prompt, wrapped_callback); |
| 137 | + ET_LOG(Info, "Completed inference from MTKLlamaRunner"); |
| 138 | + |
| 139 | + return Error::Ok; |
| 140 | +} |
| 141 | + |
| 142 | +void MTKLlamaRunner::stop() { |
| 143 | + if (is_loaded()) { |
| 144 | + runtime_->Release(); |
| 145 | + } else { |
| 146 | + ET_LOG(Error, "Llama Runtime is not loaded, cannot stop"); |
| 147 | + } |
| 148 | +} |
| 149 | + |
| 150 | +LlamaModelOptions MTKLlamaRunner::get_model_options() { |
| 151 | + LlamaModelOptions options = { |
| 152 | + // Sizes |
| 153 | + .prompt_token_batch_size = PROMPT_TOKEN_BATCH_SIZE, |
| 154 | + .cache_size = CACHE_SIZE, |
| 155 | + .hidden_size = HIDDEN_SIZE, |
| 156 | + .num_head = NUM_HEAD, |
| 157 | + .num_layer = NUM_LAYER, |
| 158 | + .max_token_length = MAX_TOKEN_LENGTH, |
| 159 | + .rot_emb_base = ROT_EMB_BASE, |
| 160 | + |
| 161 | + // Types |
| 162 | + .model_input_type = MODEL_INPUT_TYPE, |
| 163 | + .model_output_type = MODEL_OUTPUT_TYPE, |
| 164 | + .cache_type = CACHE_TYPE, |
| 165 | + .mask_type = MASK_TYPE, |
| 166 | + .rot_emb_type = ROT_EMB_TYPE}; |
| 167 | + ET_LOG(Info, "Completed get_model_options"); |
| 168 | + return options; |
| 169 | +} |
| 170 | + |
| 171 | +LlamaModelPaths MTKLlamaRunner::get_model_paths() { |
| 172 | + LlamaModelPaths model_paths = { |
| 173 | + .tokenizer_path = TOKENIZER_PATH, |
| 174 | + .token_embedding_path = TOKEN_EMBEDDING_PATH, |
| 175 | + .prompt_model_paths = utils::split(PROMPT_MODEL_PATHS, ','), |
| 176 | + .gen_model_paths = utils::split(GEN_MODEL_PATHS, ',')}; |
| 177 | + ET_LOG(Info, "Completed get_model_paths"); |
| 178 | + return model_paths; |
| 179 | +} |
| 180 | + |
| 181 | +Result<uint64_t> MTKLlamaRunner::digest_prompt( |
| 182 | + LlamaRuntime& llama_runtime, |
| 183 | + const std::unique_ptr<Tokenizer>& tokenizer, |
| 184 | + const std::vector<uint64_t> input_tokens) { |
| 185 | + const auto input_token_count = input_tokens.size(); |
| 186 | + const auto prompt_token_batch_size = llama_runtime.GetTokenBatchSize(); |
| 187 | + size_t cur_token_index = 0; |
| 188 | + |
| 189 | + Timer timer_digest_prompt([=](const auto elapsed_sec) { |
| 190 | + // Ideal prompt size is a multiple of prompt batch size |
| 191 | + const size_t ideal_prompt_size = |
| 192 | + std::ceil(float(input_token_count) / prompt_token_batch_size) * |
| 193 | + prompt_token_batch_size; |
| 194 | + ET_LOG( |
| 195 | + Info, |
| 196 | + "Done analyzing prompt in %f sec (%f tok/s)", |
| 197 | + elapsed_sec, |
| 198 | + (float)ideal_prompt_size / elapsed_sec); |
| 199 | + }); |
| 200 | + |
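|    | +  // Feed the prompt to the model in chunks. The first chunk takes the
|    | +  // remainder (prompt length % batch size) so that every subsequent chunk is
|    | +  // exactly one full prompt token batch.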
| 201 | + auto getNextTokens = [&]() { |
| 202 | + const size_t num_tok_remain = input_token_count - cur_token_index; |
| 203 | + const size_t remainder = num_tok_remain % prompt_token_batch_size; |
| 204 | + const size_t num_new_tokens = |
| 205 | + remainder ? remainder : prompt_token_batch_size; |
| 206 | + const auto start = cur_token_index; |
| 207 | + const auto end = start + num_new_tokens; |
| 208 | + return std::vector( |
| 209 | + input_tokens.begin() + start, input_tokens.begin() + end); |
| 210 | + }; |
| 211 | + |
| 212 | +  void* logits = nullptr;
| 213 | + timer_digest_prompt.Start(); |
| 214 | + while (cur_token_index < input_token_count) { |
| 215 | + const auto next_tokens = getNextTokens(); |
| 216 | + ET_LOG( |
| 217 | + Debug, |
| 218 | + "Digest next tokens (size=%zu), 1st tok=%lu", |
| 219 | + next_tokens.size(), |
| 220 | + next_tokens[0]); |
| 221 | + logits = llama_runtime.Run(next_tokens); |
| 222 | + cur_token_index += next_tokens.size(); |
| 223 | + } |
| 224 | + timer_digest_prompt.End(); |
| 225 | + |
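|    | +  // The first generated token is the argmax over the logits produced by the
|    | +  // final prompt chunk.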
| 226 | + const auto vocab_size = tokenizer->vocab_size(); |
| 227 | + const auto logits_type = llama_runtime.GetModelOptions().model_output_type; |
| 228 | + const auto first_output_token = |
| 229 | + utils::argmax(logits_type, logits, vocab_size); |
| 230 | + return first_output_token; |
| 231 | +} |
| 232 | + |
| 233 | +Error MTKLlamaRunner::gen_response( |
| 234 | + LlamaRuntime& llama_runtime, |
| 235 | + const std::unique_ptr<Tokenizer>& tokenizer, |
| 236 | + const uint64_t input_token, |
| 237 | + std::function<void(const std::string&)> token_callback) { |
| 238 | + Timer timer_model_swap( |
| 239 | + [](const auto elapsed_sec) { ET_LOG(Info, "Model swapped."); }); |
| 240 | + |
| 241 | + // Swap to gen mode |
| 242 | + timer_model_swap.Start(); |
| 243 | + llama_runtime.SwapModel(1); |
| 244 | + timer_model_swap.End(); |
| 245 | + |
| 246 | + size_t gen_tok_count = 0; |
| 247 | + uint64_t prev_token = input_token; |
| 248 | + uint64_t output_token = input_token; |
| 249 | + |
| 250 | + auto decode_res = tokenizer->decode(prev_token, output_token); |
| 251 | + ET_CHECK_OR_RETURN_ERROR( |
| 252 | + decode_res.ok(), |
| 253 | + InvalidState, |
| 254 | + "Tokenizer failed to decode first generated token: %lu", |
| 255 | + output_token); |
| 256 | + std::string full_response = std::move(decode_res.get()); |
| 257 | + std::vector<uint64_t> full_response_tokens = {input_token}; |
| 258 | + |
| 259 | + const auto vocab_size = tokenizer->vocab_size(); |
| 260 | + const auto logits_type = llama_runtime.GetModelOptions().model_output_type; |
| 261 | + |
| 262 | + double gen_total_time_sec = 0; |
| 263 | + Timer timer_gen_token( |
| 264 | + [&](const auto elapsed_sec) { gen_total_time_sec += elapsed_sec; }); |
| 265 | + |
| 266 | + // Print first output token |
| 267 | + token_callback(full_response); |
| 268 | + |
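|    | +  // Greedy (argmax) decoding: generate one token per step until EOS is
|    | +  // produced, MAX_RESPONSE tokens have been generated, or the model's
|    | +  // maximum token length is reached.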
| 269 | + while (gen_tok_count++ < MAX_RESPONSE && |
| 270 | + llama_runtime.GetTokenIndex() < modeloptions_.max_token_length) { |
| 271 | + timer_gen_token.Start(); |
| 272 | + void* logits = llama_runtime.Run({output_token}); |
| 273 | + timer_gen_token.End(); |
| 274 | + |
| 275 | + prev_token = output_token; |
| 276 | + output_token = utils::argmax(logits_type, logits, vocab_size); |
| 277 | + full_response_tokens.push_back(output_token); |
| 278 | + |
| 279 | + // Stop when output is EOS |
| 280 | + if (output_token == tokenizer->eos_tok()) { |
| 281 | + token_callback("</eos>"); |
| 282 | + break; |
| 283 | + } |
| 284 | + auto decode_res = tokenizer->decode(prev_token, output_token); |
| 285 | + ET_CHECK_OR_RETURN_ERROR( |
| 286 | + decode_res.ok(), |
| 287 | + InvalidState, |
| 288 | + "Tokenizer failed to decode generated token %lu", |
| 289 | + output_token); |
| 290 | + const std::string tok_str = std::move(decode_res.get()); |
| 291 | + full_response += tok_str; |
| 292 | + token_callback(tok_str); |
| 293 | + } |
| 294 | + |
| 295 | + std::cout << "\n\n[Generated Tokens]\n" |
| 296 | + << utils::to_string(full_response_tokens) << std::endl; |
| 297 | + |
| 298 | + ET_LOG( |
| 299 | + Info, |
| 300 | + "Token generation speed: %f tok/s", |
| 301 | + gen_tok_count / gen_total_time_sec); |
| 302 | + |
| 303 | + return Error::Ok; |
| 304 | +} |
| 305 | + |
| 306 | +Error MTKLlamaRunner::inference( |
| 307 | + LlamaRuntime& llama_runtime, |
| 308 | + const std::unique_ptr<Tokenizer>& tokenizer, |
| 309 | + const std::string& prompt, |
| 310 | + std::function<void(const std::string&)> token_callback) { |
| 311 | + // Tokenize input prompt |
| 312 | + auto encode_res = tokenizer->encode(prompt, kAddBos, kAddEos); |
| 313 | + ET_CHECK_OR_RETURN_ERROR( |
| 314 | + encode_res.ok(), InvalidState, "Tokenizer failed to encode prompt"); |
| 315 | + const auto input_tokens = std::move(encode_res.get()); |
| 316 | + |
| 317 | + // Run prompt mode (pre-fill) |
| 318 | + auto prefill_res = digest_prompt(llama_runtime, tokenizer, input_tokens); |
| 319 | + ET_CHECK_OR_RETURN_ERROR( |
| 320 | + prefill_res.ok(), InvalidState, "Failed to digest prompt"); |
| 321 | + const auto first_output_token = prefill_res.get(); |
| 322 | + |
| 323 | +  // Run generation mode (decoding)
| 324 | + return gen_response(llama_runtime, tokenizer, first_output_token, token_callback); |
| 325 | +} |
| 326 | + |
| 327 | +std::unique_ptr<Tokenizer> MTKLlamaRunner::load_tokenizer() { |
| 328 | + std::unique_ptr<Tokenizer> tokenizer; |
| 329 | + // Assumes that tokenizer type is Tiktoken |
| 330 | + tokenizer = torch::executor::get_tiktoken_for_llama(); |
| 331 | + tokenizer->load(modelpaths_.tokenizer_path); |
| 332 | + return tokenizer; |
| 333 | +} |