Qualcomm AI Engine Direct - Refine Llama3 Tokenizer #4940
```diff
@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-import gc
 import json
 import os
 from multiprocessing.connection import Client
```
```diff
@@ -15,18 +14,19 @@
     QcomChipset,
 )
 from executorch.backends.qualcomm.utils.utils import (
-    canonicalize_program,
     from_context_binary,
     generate_htp_compiler_spec,
     generate_qnn_executorch_compiler_spec,
     generate_qnn_executorch_option,
 )
+from executorch.examples.qualcomm.qaihub_scripts.utils.utils import (
+    gen_pte_from_ctx_bin,
+    get_encoding,
+)
 from executorch.examples.qualcomm.utils import (
     setup_common_args_and_variables,
     SimpleADB,
 )
-from executorch.exir.backend.backend_api import to_backend
-from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass


 def main(args):
```
```diff
@@ -55,45 +55,25 @@ def main(args):
         is_from_context_binary=True,
     )

-    pte_name = (
-        "qaihub_llama2_7b_prompt"
-        if args.use_prompt_processor
-        else "qaihub_llama2_7b_token"
-    )
+    if args.use_prompt_processor:
+        pte_name = "qaihub_llama2_7b_prompt"
+        last_shard_num_inputs = 4
+        last_shard_num_outputs = 513
+    else:
+        pte_name = "qaihub_llama2_7b_token"
+        last_shard_num_inputs = 516
+        last_shard_num_outputs = 513

     if args.pre_gen_pte is None:
         # create custom operators as context loader
         bundle_programs = [
             from_context_binary(f"{args.context_binaries}/{target}", f"ctx_loader_{i}")
             for i, target in enumerate(target_names)
         ]
-        # lower with QnnBackend
-        lowered_modules = [
-            to_backend("QnnBackend", prog["edge_program"], compiler_specs)
-            for prog in bundle_programs
-        ]
-        # setup spill-fill buffer for relieving runtime memory usage
-        canonicalize_program(lowered_modules)
-        # export pte files
-        pte_files = []
-        for i in range(len(target_names)):
-            print(f"pte {i} generating...")
-            memory_planning_pass = MemoryPlanningPass(
-                memory_planning_algo="greedy",
-                alloc_graph_input=False,
-                alloc_graph_output=False,
-            )
-            pte_files.append(f"{args.artifact}/{pte_name}_{i}.pte")
-            with open(pte_files[-1], "wb") as file:
-                file.write(
-                    lowered_modules[0].buffer(
-                        extract_delegate_segments=True,
-                        memory_planning=memory_planning_pass,
-                    )
-                )
-            # gc for reducing host memory consuming
-            bundle_programs.pop(0)
-            lowered_modules.pop(0)
-            gc.collect()
+        pte_names = [f"{pte_name}_{i}" for i in range(len(target_names))]
+        pte_files = gen_pte_from_ctx_bin(
+            args.artifact, pte_names, compiler_specs, bundle_programs
+        )
     else:
         pte_files = [f"{args.pre_gen_pte}/{pte_name}_{i}.pte" for i in range(4)]
```
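For context, `gen_pte_from_ctx_bin` takes over the per-shard lowering and export loop deleted above. Below is a minimal sketch of what such a helper consolidates, reconstructed from the removed lines; the actual implementation in `executorch.examples.qualcomm.qaihub_scripts.utils.utils` may differ in signature and details:

```python
# Sketch only: reconstructed from the code this PR removes, not the actual
# helper shipped in executorch.examples.qualcomm.qaihub_scripts.utils.utils.
import gc

from executorch.backends.qualcomm.utils.utils import canonicalize_program
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass


def gen_pte_from_ctx_bin(artifact, pte_names, compiler_specs, bundle_programs):
    # lower every context-binary shard with the QNN backend
    lowered_modules = [
        to_backend("QnnBackend", prog["edge_program"], compiler_specs)
        for prog in bundle_programs
    ]
    # set up the spill-fill buffer to relieve runtime memory usage
    canonicalize_program(lowered_modules)
    pte_files = []
    for pte_name in pte_names:
        memory_planning_pass = MemoryPlanningPass(
            memory_planning_algo="greedy",
            alloc_graph_input=False,
            alloc_graph_output=False,
        )
        pte_files.append(f"{artifact}/{pte_name}.pte")
        with open(pte_files[-1], "wb") as file:
            file.write(
                lowered_modules[0].buffer(
                    extract_delegate_segments=True,
                    memory_planning=memory_planning_pass,
                )
            )
        # drop each shard as soon as it is serialized so peak host memory
        # stays near one shard rather than all of them
        bundle_programs.pop(0)
        lowered_modules.pop(0)
        gc.collect()
    return pte_files
```

Moving this into a shared utils module lets the other qaihub scripts reuse the same export path instead of duplicating it inline.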
```diff
@@ -125,7 +105,16 @@ def get_logit_encoding(path_to_last_shard: str):
     )
     output_file = "result.txt"
     pos_embs_file = ["freq_cos", "freq_sin"]
-    scale, offset = get_logit_encoding(target_names[-1])
+    encoding = get_encoding(
+        path_to_shard=f"{args.context_binaries}/{target_names[-1]}",
+        compiler_specs=compiler_specs,
+        get_input=False,
+        get_output=True,
+        num_input=last_shard_num_inputs,
+        num_output=last_shard_num_outputs,
+    )[0]
+    scale = encoding["scale"][-1]
+    offset = encoding["offset"][-1]
     outputs = []
     runner_args = [
         *[
```
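The `scale` and `offset` retrieved here describe an affine uint16 quantization of the last shard's output; `post_process` below applies the same mapping to the positional-embedding files. A small sketch of both directions of that mapping, where the quantize formula is taken from this script and the helper names are purely illustrative:

```python
import torch


def quantize_u16(x: torch.Tensor, scale: float, offset: float) -> torch.Tensor:
    # same formula the script applies to freq_cos/freq_sin before pushing
    return (x / scale + offset).clip(min=0, max=65535).to(dtype=torch.uint16)


def dequantize_u16(q: torch.Tensor, scale: float, offset: float) -> torch.Tensor:
    # inverse mapping, e.g. for reading quantized logits back on the host
    return (q.to(torch.float32) - offset) * scale
```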
```diff
@@ -173,7 +162,8 @@ def post_process():
         freq = (freq / scale + offset).clip(min=0, max=65535).detach()
         freq.to(dtype=torch.uint16).numpy().tofile(custom_files[-1])

-    adb.push(files=custom_files)
+    if not args.skip_push:
+        adb.push(files=custom_files)
     adb.execute(custom_runner_cmd=runner_cmds)
     adb.pull(args.artifact, callback=post_process)
     if args.ip and args.port != -1:
```
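The new guard reads `args.skip_push`, whose definition is outside this diff. A plausible parser entry, shown purely as an assumption:

```python
# assumed definition; the real flag is declared elsewhere in this script
parser.add_argument(
    "--skip_push",
    action="store_true",
    help="skip pushing custom input files that a previous run already pushed",
)
```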
```diff
@@ -230,7 +220,7 @@ def post_process():
     parser.add_argument(
         "--temperature",
         help="sampling temperature for llama2",
-        default=0.8,
+        default=0.0,
         type=float,
     )
```

Review discussion on the `default=0.0` change:

Reviewer: Any specific reason we're using 0 temperature?

Author: We changed the default to 0 because the output is more consistent, which is better for testing purposes.
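On the reviewer's question: a common convention, assumed here since the runner code is not part of this diff, is that `temperature == 0` selects greedy argmax decoding, which is what makes the output deterministic:

```python
import torch


def sample_next_token(logits: torch.Tensor, temperature: float) -> int:
    # temperature 0 -> greedy decoding: identical prompts give identical
    # outputs, which keeps test runs comparable
    if temperature == 0.0:
        return int(torch.argmax(logits))
    # temperature > 0 -> rescale the logits, then sample from the softmax
    probs = torch.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1))
```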
Reviewer: Thanks for fixing this. I reverted the change to use build-x86 instead, and it seems like some cases are missing.