
Commit a876626

Qualcomm AI Engine Direct - Refactor llama runner
Summary:
- Refactored io_manager into five distinct components:
  - DecoderRunner: decoder Module wrapper class.
  - PromptProcessor: handles prompt processing using the decoder and key-value manager.
  - TokenGenerator: generates tokens using the decoder and key-value manager.
  - KVManager: manages the key-value cache with kv_updater, including data buffer allocation, cache updates, and buffer updates in TensorImpl.
  - IMemAlloc: interface for allocating data buffers from RPC memory (RpcMem) or a client buffer (ClientMem).
- Support the multi-turn use case.
- Validated on story llama. To simulate the multi-turn scenario, decode mode was forced to generate 5 tokens per round, and tokens of random length were inserted after each prefill->decode round finished.
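To make the new layout concrete, here is a rough C++ sketch of how the five pieces could fit together. The class names follow the summary; every signature below is an illustrative assumption, not the actual headers under examples/qualcomm/oss_scripts/llama/runner/.

#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

namespace example {

// Illustrative sketch only: names mirror the refactor, signatures are guesses.
class IMemAlloc { // buffer source; implemented by RpcMem and ClientMem
 public:
  virtual ~IMemAlloc() = default;
  virtual std::byte* allocate(size_t data_size) = 0;
};

class DecoderRunner { // thin wrapper around the loaded decoder Module
 public:
  void load(const std::string& model_path);
};

class KVManager { // owns the KV cache; applies the kv_updater policy
 public:
  explicit KVManager(IMemAlloc* buffer_alloc);
  void update_cache(int64_t pos /* , new k/v entries, ... */);
};

class PromptProcessor { // prefill: runs the whole prompt through the decoder
 public:
  PromptProcessor(DecoderRunner* decoder, KVManager* kv_manager);
  void prefill(const std::vector<uint64_t>& prompt_tokens);
};

class TokenGenerator { // decode: produces one token per forward step
 public:
  TokenGenerator(DecoderRunner* decoder, KVManager* kv_manager);
  uint64_t step(uint64_t prev_token, int64_t pos);
};

} // namespace example

The multi-turn support plausibly falls out of this split: the KV cache lives in KVManager rather than in a per-call io_manager, so a second prompt can be prefilled on top of the cache left by the previous prefill->decode round.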
1 parent c5dd476 commit a876626

25 files changed: +1988 −2367 lines

backends/qualcomm/runtime/SharedBuffer.cpp

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ std::size_t std::hash<CustomMemTensorInfo>::operator()(
   hash_val ^= std::hash<size_t>()(info.pos);
   hash_val ^= std::hash<size_t>()(info.tensor_bytes);
   for (int i = 0; i < info.rank; ++i) {
-    hash_val ^= info.shape[i];
+    hash_val ^= std::hash<uint32_t>()(info.shape[i]);
   }
   hash_val ^= std::hash<uint32_t>()(info.rank);
   hash_val ^= std::hash<executorch::aten::ScalarType>()(info.dtype);
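Note the surrounding fields (pos, tensor_bytes, rank, dtype) were already folded in through std::hash; the one-line fix makes the shape dimensions follow the same pattern instead of XORing raw integers into the accumulator. A minimal standalone sketch of the combining scheme (the key is reduced to shape and rank here; the real CustomMemTensorInfo has more fields):

#include <cstddef>
#include <cstdint>
#include <functional>

// Sketch of the hash-combining pattern in SharedBuffer.cpp: every
// component is passed through std::hash before being XORed in.
std::size_t hash_shape(const uint32_t* shape, uint32_t rank) {
  std::size_t hash_val = 0;
  for (uint32_t i = 0; i < rank; ++i) {
    // The fix: hash each dimension (instead of `hash_val ^= shape[i]`)
    // so every field of the key is combined the same way.
    hash_val ^= std::hash<uint32_t>()(shape[i]);
  }
  hash_val ^= std::hash<uint32_t>()(rank);
  return hash_val;
}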

backends/qualcomm/runtime/backends/QnnBackendFactory.cpp

Lines changed: 3 additions & 1 deletion
@@ -78,7 +78,9 @@ std::unique_ptr<BackendConfigParameters> QnnBackendFactory::Create(
         options->soc_info(),
         htp_options);
     backend_params->qnn_mem_manager_ptr_ = std::make_unique<QnnMemManager>(
-        implementation, backend_params->qnn_context_ptr_.get());
+        implementation,
+        backend_params->qnn_context_ptr_.get(),
+        options->log_level());
     backend_params->backend_init_state_ = BackendInitializeState::INITIALIZED;
   } break;
   case QnnExecuTorchBackendType::kGpuBackend:

backends/qualcomm/runtime/backends/QnnMemManager.cpp

Lines changed: 11 additions & 6 deletions
@@ -47,9 +47,12 @@ Error QnnMemManager::RegisterIonMem(
   }
   tensor_wrapper->SetMemHandle(handle);
   registered_map_.insert({handle, mem_ptr});
-  QNN_EXECUTORCH_LOG_INFO(
-      "Tensor %s is successfully registered to ION shared memory.",
-      tensor_wrapper->GetName().c_str());
+  if (log_level_ >= QnnExecuTorchLogLevel::kLogLevelInfo) {
+    QNN_EXECUTORCH_LOG_INFO(
+        "Tensor %s is successfully registered to ION shared memory.",
+        tensor_wrapper->GetName().c_str());
+  }
+
   return Error::Ok;
 }

@@ -92,9 +95,11 @@ Error QnnMemManager::RegisterCustomMem(
   }
   tensor_wrapper->SetMemHandle(handle);
   registered_map_.insert({handle, mem_ptr});
-  QNN_EXECUTORCH_LOG_INFO(
-      "Tensor %s is successfully registered to custom shared memory.",
-      tensor_wrapper->GetName().c_str());
+  if (log_level_ >= QnnExecuTorchLogLevel::kLogLevelInfo) {
+    QNN_EXECUTORCH_LOG_INFO(
+        "Tensor %s is successfully registered to custom shared memory.",
+        tensor_wrapper->GetName().c_str());
+  }
   return Error::Ok;
 }

backends/qualcomm/runtime/backends/QnnMemManager.h

Lines changed: 6 additions & 2 deletions
@@ -21,8 +21,11 @@ class QnnMemManager {
  public:
   explicit QnnMemManager(
       const QnnImplementation& implementation,
-      QnnContext* context)
-      : implementation_(implementation), context_(context) {}
+      QnnContext* context,
+      QnnExecuTorchLogLevel log_level)
+      : implementation_(implementation),
+        context_(context),
+        log_level_(log_level) {}
   ~QnnMemManager() {
     DeRegisterMem();
   }
@@ -63,6 +66,7 @@ class QnnMemManager {
 
   const QnnImplementation& implementation_;
   QnnContext* context_;
+  QnnExecuTorchLogLevel log_level_;
   std::unordered_map<Qnn_MemHandle_t, void*> registered_map_;
   std::unordered_map<CustomMemTensorInfo, void*> pre_registered_handles_;
   std::unordered_map<executorch::aten::ScalarType, Qnn_DataType_t>
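The header and factory changes above thread the configured log level into QnnMemManager so the per-tensor registration messages in QnnMemManager.cpp can be skipped entirely when INFO logging is off. A standalone sketch of the gating pattern (the enum and logging call are simplified stand-ins for the QNN ExecuTorch definitions, whose exact values may differ):

#include <cstdio>

// Stand-in for QnnExecuTorchLogLevel; ordering assumed for illustration.
enum class LogLevel { kLogOff = 0, kLogLevelError, kLogLevelWarn, kLogLevelInfo };

class MemManagerSketch {
 public:
  explicit MemManagerSketch(LogLevel log_level) : log_level_(log_level) {}

  void register_tensor(const char* name) {
    // ... registration work elided ...
    // Gate the success message: with many tensors registered per model
    // load, formatting this line only at INFO keeps lower levels quiet.
    if (log_level_ >= LogLevel::kLogLevelInfo) {
      std::printf("Tensor %s is successfully registered.\n", name);
    }
  }

 private:
  LogLevel log_level_;
};

Injecting the level at construction matches the factory change above, where options->log_level() is passed alongside the context.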

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 1 addition & 1 deletion
@@ -3463,7 +3463,7 @@ def test_llama3_2_1b(self):
         if self.pre_gen_pte:
             cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
 
-        golden_start_with = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>"
+        golden_start_with = "<|start_header_id|>user<|end_header_id|>"
         p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
         with Listener((self.ip, self.port)) as listener:
             conn = listener.accept()

examples/qualcomm/oss_scripts/llama/CMakeLists.txt

Lines changed: 13 additions & 3 deletions
@@ -28,8 +28,18 @@ list(
     ${CMAKE_CURRENT_LIST_DIR}/qnn_llama_runner.cpp
     ${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp
     ${CMAKE_CURRENT_LIST_DIR}/runner/runner.h
-    ${CMAKE_CURRENT_LIST_DIR}/runner/io_manager.cpp
-    ${CMAKE_CURRENT_LIST_DIR}/runner/io_manager.h
+    ${CMAKE_CURRENT_LIST_DIR}/runner/decoder_runner.cpp
+    ${CMAKE_CURRENT_LIST_DIR}/runner/decoder_runner.h
+    ${CMAKE_CURRENT_LIST_DIR}/runner/prompt_processor.cpp
+    ${CMAKE_CURRENT_LIST_DIR}/runner/prompt_processor.h
+    ${CMAKE_CURRENT_LIST_DIR}/runner/token_generator.cpp
+    ${CMAKE_CURRENT_LIST_DIR}/runner/token_generator.h
+    ${CMAKE_CURRENT_LIST_DIR}/runner/imem_alloc.h
+    ${CMAKE_CURRENT_LIST_DIR}/runner/client_mem.h
+    ${CMAKE_CURRENT_LIST_DIR}/runner/rpc_mem.cpp
+    ${CMAKE_CURRENT_LIST_DIR}/runner/rpc_mem.h
+    ${CMAKE_CURRENT_LIST_DIR}/runner/kv_manager.cpp
+    ${CMAKE_CURRENT_LIST_DIR}/runner/kv_manager.h
 )
 
 list(
@@ -42,7 +52,7 @@ list(
 # build qnn llama runner
 add_executable(qnn_llama_runner ${_llama_runner__srcs})
 target_include_directories(
-  qnn_llama_runner PUBLIC ${_common_include_directories}
+  qnn_llama_runner PUBLIC ${_common_include_directories} ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizers/include
 )
 
 target_link_options_shared_lib(quantized_ops_lib)

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 22 additions & 48 deletions
@@ -403,7 +403,7 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
         logging.info("Quantizing the model...")
         calibrate(
             self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
-            args.prompt,
+            args.prompt[0],
             fx_graph_module,
             tokenizer=tokenizer,
             ar_len=self.llama_meta["get_ar_len"],
@@ -828,7 +828,7 @@ def permute(w, heads):
     return quant_attrs
 
 
-def inference(args, quant_attrs, pte_filename, runtime_tokenizer_path, pre_gen_pte=""):
+def inference(args, pte_filename, runtime_tokenizer_path, pre_gen_pte=""):
     workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama"
 
     if args.model_mode == "kv":
@@ -854,14 +854,13 @@ def post_process():
             outputs.append(f.read())
 
     seq_len = args.max_seq_len
+    multi_prompts = " ".join([f'--prompt "{prompt}"' for prompt in args.prompt])
     runner_args = " ".join(
         [
-            f'--prompt "{args.prompt}"',
+            multi_prompts,
             f"--eval_mode {eval_mode}",
             f"--temperature {args.temperature}",
             f"--system_prompt '{args.system_prompt}'",
-            f"--logits_scale {quant_attrs['scale']}",
-            f"--logits_offset {quant_attrs['zero_point']}",
         ]
     )

@@ -1004,9 +1003,10 @@ def _build_parser():
 
     parser.add_argument(
         "--prompt",
-        help="User prompts for llama.",
+        help="User prompts for Llama. When multiple prompts are entered, a multi-turn conversation will be initiated. Note that this feature is currently for testing purposes only.",
         required=True,
         type=str,
+        nargs="+",
     )
 
     parser.add_argument(
@@ -1090,7 +1090,7 @@ def _build_parser():
 
 
 def export_llama(args) -> None:
     if args.compile_only and args.pre_gen_pte:
-        exit("Cannot set both compile_only and pre_gen_pte as true")
+        raise RuntimeError("Cannot set both compile_only and pre_gen_pte as true")
 
     if args.model_mode == "kv":
         pte_filename = "kv_llama_qnn"
@@ -1126,29 +1126,15 @@ def export_llama(args) -> None:
     elif args.kv_updater == "shift_pointer":
         args.kv_updater = shift_pointer_updater
     else:
-        exit(f"Using an unkown kv update {args.kv_updater}")
+        raise RuntimeError(f"Using an unknown kv update {args.kv_updater}")
 
     if args.pre_gen_pte:
-        quant_attrs = json.load(
-            open(f"{args.pre_gen_pte}/{pte_filename}_quant_attrs.txt")
-        )
-        inference(
-            args, quant_attrs, pte_filename, runtime_tokenizer_path, args.pre_gen_pte
-        )
-        exit(f"Finish the running pre_gen_pte from {args.pre_gen_pte}")
+        inference(args, pte_filename, runtime_tokenizer_path, args.pre_gen_pte)
+        print(f"Finish the running pre_gen_pte from {args.pre_gen_pte}")
+        return
 
     if args.compile_only:
-        quant_attrs = compile(args, pte_filename, tokenizer)
-        if quant_attrs:
-            json.dump(
-                {
-                    "scale": quant_attrs["scale"],
-                    "zero_point": quant_attrs["zero_point"],
-                },
-                open(f"{args.artifact}/{pte_filename}_quant_attrs.txt", "w"),
-            )
-        else:
-            logging.warning("Quant attributes of the logit is None.")
+        compile(args, pte_filename, tokenizer)
 
     if args.ip and args.port != -1:
         pte_path = f"{args.artifact}/{pte_filename}.pte"
@@ -1161,24 +1147,18 @@ def export_llama(args) -> None:
                 }
             )
         )
-        exit(f"Finish compile_only and save to {args.artifact}")
+        print(f"Finish compile_only and save to {args.artifact}")
+        return
+
+    compile(args, pte_filename, tokenizer)
+    inference(args, pte_filename, runtime_tokenizer_path)
 
+
+def main():
+    parser = _build_parser()
+    args = parser.parse_args()
     try:
-        quant_attrs = compile(args, pte_filename, tokenizer)
-        if quant_attrs:
-            logging.info(
-                f"Logit scale: {quant_attrs['scale']}; Logit offset: {quant_attrs['zero_point']}"
-            )
-            json.dump(
-                {
-                    "scale": quant_attrs["scale"],
-                    "zero_point": quant_attrs["zero_point"],
-                },
-                open(f"{args.artifact}/{pte_filename}_quant_attrs.txt", "w"),
-            )
-        else:
-            logging.warning("Quant attributes of the logit is None.")
-        inference(args, quant_attrs, pte_filename, runtime_tokenizer_path)
+        export_llama(args)
     except Exception as e:
         if args.ip and args.port != -1:
             with Client((args.ip, args.port)) as conn:
@@ -1187,12 +1167,6 @@ def export_llama(args) -> None:
         raise Exception(e)
 
 
-def main():
-    parser = _build_parser()
-    args = parser.parse_args()
-    export_llama(args)
-
-
 # flake8: noqa: C901
 if __name__ == "__main__":
     main()
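A note on the new prompt plumbing: with nargs="+", a multi-turn run passes all turns to a single flag, e.g. python llama.py --prompt "first turn" "second turn" (other required flags elided; an assumed invocation, for illustration). inference() then re-emits each turn as a separate --prompt "..." argument for the on-device runner, which is what CollectPrompts in qnn_llama_runner.cpp gathers back into a vector of prompts.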

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 40 additions & 11 deletions
@@ -34,7 +34,10 @@ DEFINE_string(
     "inference_speed.txt",
     "Records inference speed. For CI purpose.");
 DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff.");
-DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt.");
+DEFINE_string(
+    prompt,
+    "The answer to the ultimate question is",
+    "User prompts for Llama. When multiple prompts are entered, a multi-turn conversation will be initiated. Note that this feature is currently for testing purposes only.");
 DEFINE_string(
     system_prompt,
     "",
@@ -49,10 +52,8 @@ DEFINE_int32(
     "Total number of tokens to generate (prompt + output).");
 DEFINE_int32(
     eval_mode,
-    1,
+    0,
     "0: TokenGenerator(kv) / 1: HybridMode (prefill+kv)");
-DEFINE_double(logits_scale, 0.0, "Logits scale");
-DEFINE_int32(logits_offset, 0, "Logits offset");
 DEFINE_string(
     kv_updater,
     "How to update kv cache. Choose between SmartMask and ShiftPointer",
@@ -72,20 +73,46 @@ std::vector<std::string> CollectPrompts(int argc, char** argv) {
   return prompts;
 }
 
+std::string get_formatted_prompt(
+    const std::string& prompt,
+    const std::string& system_prompt,
+    example::LlamaVersion llama_version) {
+  std::string formatted_prompt;
+  switch (llama_version) {
+    case example::LlamaVersion::kLlama2:
+      formatted_prompt.append(prompt);
+      break;
+    case example::LlamaVersion::kLlama3:
+      if (!system_prompt.empty()) {
+        formatted_prompt.append(
+            "<|start_header_id|>system<|end_header_id|>\n\n");
+        formatted_prompt.append(system_prompt);
+        formatted_prompt.append("<|eot_id|>");
+      }
+      formatted_prompt.append("<|start_header_id|>user<|end_header_id|>\n\n");
+      formatted_prompt.append(prompt);
+      formatted_prompt.append(
+          "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
+      break;
+    default:
+      ET_CHECK_MSG(false, "unsupported llama version");
+      break;
+  }
+  return formatted_prompt;
+}
+
 int main(int argc, char** argv) {
   std::vector<std::string> prompts = CollectPrompts(argc, argv);
   gflags::ParseCommandLineFlags(&argc, &argv, true);
   // create llama runner
   example::Runner runner(
-      {FLAGS_model_path},
+      FLAGS_model_path.c_str(),
       FLAGS_tokenizer_path.c_str(),
       FLAGS_performance_output_path.c_str(),
-      FLAGS_logits_scale,
-      FLAGS_logits_offset,
       FLAGS_temperature,
       FLAGS_eval_mode,
-      FLAGS_kv_updater,
-      FLAGS_num_iters);
+      FLAGS_kv_updater);
+  auto llama_version = runner.get_llama_version();
   std::vector<char> buf;
   buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
   std::ofstream fout(FLAGS_output_path.c_str());
@@ -97,8 +124,10 @@ int main(int argc, char** argv) {
   // generate tokens & store inference output
   for (int i = 0; i < FLAGS_num_iters; i++) {
     for (const auto& prompt : prompts) {
-      runner.generate(
-          FLAGS_seq_len, prompt.c_str(), FLAGS_system_prompt.c_str(), callback);
+      std::string formatted_prompt;
+      formatted_prompt = get_formatted_prompt(
+          prompt, FLAGS_system_prompt, llama_version.get());
+      runner.generate(formatted_prompt.c_str(), FLAGS_seq_len, callback);
     }
   }
   fout.write(buf.data(), buf.size());
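For reference, the Llama 3 branch of get_formatted_prompt wraps a user prompt "Hello" (with no system prompt) as:

<|start_header_id|>user<|end_header_id|>

Hello<|eot_id|><|start_header_id|>assistant<|end_header_id|>

This matches the updated golden_start_with in test_qnn_delegate.py, which now expects output starting at the user header: the formatter never emits <|begin_of_text|>, the BOS token presumably being added by the tokenizer at runtime.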
examples/qualcomm/oss_scripts/llama/runner/client_mem.h

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) Qualcomm Innovation Center, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/examples/qualcomm/oss_scripts/llama/runner/imem_alloc.h>
+#include <vector>
+
+namespace example {
+/**
+ * @class ClientMem
+ * @brief Final class for client buffer allocation, implementing the IMemAlloc
+ * interface. Used for SHIFT_POINTER mode.
+ */
+class ClientMem final : public IMemAlloc {
+ public:
+  ClientMem(){};
+  // Disable copy constructors, r-value referencing, etc.
+  ClientMem(const ClientMem&) = delete;
+  ClientMem& operator=(const ClientMem&) = delete;
+  ClientMem(ClientMem&&) = delete;
+  ClientMem& operator=(ClientMem&&) = delete;
+  virtual ~ClientMem(){};
+  /**
+   * @brief Allocate buffer of specified size with vector.
+   * @param data_size Size of the data to allocate.
+   * @return Pointer to the allocated buffer.
+   */
+  std::byte* allocate(size_t data_size) override {
+    allocated_buffers_.push_back(std::vector<std::byte>(data_size));
+    return allocated_buffers_.back().data();
+  };
+  // Only used for SMART_MASK mode.
+  void add_memory_info(
+      void* data_ptr,
+      size_t data_size,
+      executorch::runtime::TensorInfo tensor_info) override{};
+
+ private:
+  std::vector<std::vector<std::byte>> allocated_buffers_;
+};
+
+} // namespace example
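A hedged usage sketch for the new allocator interface: how a KV manager might draw per-layer cache buffers from a ClientMem instance. The sizing and loop are invented for illustration; only allocate() and add_memory_info() come from the header above.

#include <cstddef>
// In the real tree: #include the client_mem.h header shown above.

// Hypothetical sizing: one K and one V buffer per decoder layer.
void allocate_kv_buffers(
    example::ClientMem& buffer_alloc,
    size_t num_layers,
    size_t bytes_per_cache) {
  for (size_t layer = 0; layer < num_layers; ++layer) {
    std::byte* k_cache = buffer_alloc.allocate(bytes_per_cache);
    std::byte* v_cache = buffer_alloc.allocate(bytes_per_cache);
    // In SHIFT_POINTER mode these plain heap buffers back the cache
    // tensors directly; add_memory_info() is a no-op here and only
    // matters for the RpcMem (SMART_MASK) implementation.
    (void)k_cache;
    (void)v_cache;
  }
}

One detail worth noting: allocated_buffers_ is a vector of vectors, so the outer vector may reallocate as buffers are added, but each inner buffer's heap storage, and therefore every pointer returned by allocate(), remains valid.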
