
Commit 5161d70
Patch llama.py for internal build (#7239)

Summary: As title; minor changes so we can use buck to build.

```
buck run mode/dev-nosan //executorch/examples/qualcomm/oss_scripts/llama3_2:llama_qnn -- --compile_only --ptq 16a4w --checkpoint /home/chenlai/local/models/consolidated.00.pth --params /home/chenlai/local/models/params.json --tokenizer_model /home/chenlai/local/models/tokenizer.model --prompt "Once" -m SM8650 --model_size 1B --model_mode kv 2>&1 | tee static_llama.log
```

Differential Revision: D66947240
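The reason a named `main()` helps here: Buck's `python_binary` rule references a named entry point rather than top-level script code behind an `if __name__ == "__main__":` guard. Below is a hypothetical sketch of what the internal target could look like; the target names, the Buck2-style `main_function` attribute, and the dep layout are assumptions, not taken from the actual internal BUCK file:

```python
# BUCK (hypothetical sketch; the real internal target layout may differ)
python_binary(
    name = "llama_qnn",
    # Buck2-style fully qualified entry point; this is why llama.py needs a
    # named main() instead of bare module-level script code.
    main_function = "executorch.examples.qualcomm.oss_scripts.llama3_2.llama.main",
    deps = [
        ":llama_lib",  # assumed library target wrapping llama.py and its deps
    ],
)
```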
1 parent f22d1a3 commit 5161d70

File tree

2 files changed: +17 -7 lines


examples/qualcomm/oss_scripts/llama2/llama.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -492,6 +492,7 @@ def inference(args, pre_gen_pte=""):
         f"model_mode {args.model_mode} is not implemented yet."
     )
 
+    assert args.tokenizer_bin is not None, "Need tokenizer model for inference"
     runner_args = " ".join(
         [
             f"--model_path {pte_filename}.pte",
@@ -562,8 +563,7 @@ def post_process():
         print(f"Results[{idx}]:\n{output}")
 
 
-# flake8: noqa: C901
-if __name__ == "__main__":
+def main():
     parser = setup_common_args_and_variables()
     parser.add_argument(
         "-a",
@@ -597,7 +597,7 @@ def post_process():
     parser.add_argument(
         "--tokenizer_bin",
         help="Pass llama2 tokenizer binary.",
-        required=True,
+        required=False,
         type=str,
     )
@@ -680,3 +680,8 @@ def post_process():
                 conn.send(json.dumps({"Error": str(e)}))
         else:
             raise Exception(e)
+
+
+# flake8: noqa: C901
+if __name__ == "__main__":
+    main()
```
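Taken together, the llama2 changes make the tokenizer a lazily enforced requirement: the parser no longer rejects a `--compile_only` invocation that omits `--tokenizer_bin`, and the new assert in `inference()` fires only when the tokenizer is actually needed. A minimal sketch of the resulting control flow, assuming a schematic dispatch (the real `main()` has more branches and helpers than shown):

```python
def main():
    parser = setup_common_args_and_variables()
    # --tokenizer_bin is now required=False at parse time...
    args = parser.parse_args()
    if args.compile_only:
        compile(args)    # no tokenizer needed just to build the .pte
    else:
        # ...and only enforced here, inside inference(), via the new assert
        inference(args)


if __name__ == "__main__":
    main()
```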

examples/qualcomm/oss_scripts/llama3_2/llama.py

Lines changed: 9 additions & 4 deletions
```diff
@@ -218,7 +218,7 @@ def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type, sharding_type):
         ] + [n.target]:
             n.meta[QCOM_QUANTIZED_IO] = sharding_type
 
-    def quantize(self, quant_dtype, custom_annotations=()):
+    def quantize(self, quant_dtype, args, custom_annotations=()):
         self.quant_dtype = quant_dtype
         quantizer = make_quantizer(
             quant_dtype=quant_dtype,
@@ -386,7 +386,8 @@ def compile(args):
     if args.ptq != None:
         start_quantize_ts = time.time()
         single_llama.quantize(
-            quant_dtype,
+            quant_dtype=quant_dtype,
+            args=args,
             custom_annotations=(
                 custom_annotate_llama_last_conv_16a8w,
                 matmul_annotate_func,
@@ -486,8 +487,7 @@ def post_process():
         logging.info(f"Results[{idx}]:\n{output}")
 
 
-# flake8: noqa: C901
-if __name__ == "__main__":
+def main():
     parser = setup_common_args_and_variables()
     parser.add_argument(
         "-a",
@@ -605,3 +605,8 @@ def post_process():
                 conn.send(json.dumps({"Error": str(e)}))
         else:
             raise Exception(e)
+
+
+# flake8: noqa: C901
+if __name__ == "__main__":
+    main()
```
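One detail worth noting in the llama3_2 change: because the new `args` parameter sits between `quant_dtype` and `custom_annotations`, a legacy positional call would silently bind the annotation tuple to `args`. Switching the call site to keywords, as the diff does, rules that out; a sketch of the contrast:

```python
# Old positional style would now mis-bind: the tuple lands in `args`.
# single_llama.quantize(quant_dtype, (custom_annotate_llama_last_conv_16a8w,))

# Keyword form used by this patch stays unambiguous:
single_llama.quantize(
    quant_dtype=quant_dtype,
    args=args,
    custom_annotations=(custom_annotate_llama_last_conv_16a8w, matmul_annotate_func),
)
```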
