Skip to content

Commit 1b5b3ee

Browse files
limintangfacebook-github-bot
authored andcommitted
Refactor and add Llama Python library build (#8107)
Summary: As title. To use static llama export outside QC dir. Reviewed By: cccclai Differential Revision: D68937637
1 parent e8ee36c commit 1b5b3ee

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

examples/qualcomm/oss_scripts/llama/TARGETS

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,24 @@ python_library(
1515
],
1616
)
1717

18+
python_library(
19+
name = "llama_lib",
20+
srcs = ["llama.py"],
21+
deps = [
22+
"//caffe2:torch",
23+
"//executorch/backends/qualcomm/partition:partition",
24+
"//executorch/backends/qualcomm/quantizer:quantizer",
25+
"//executorch/devtools:lib",
26+
"//executorch/examples/models:models",
27+
"//executorch/examples/qualcomm/oss_scripts/llama:static_llama",
28+
"//executorch/examples/qualcomm:utils",
29+
"//executorch/extension/export_util:export_util",
30+
"//executorch/extension/llm/custom_ops:model_sharding_py",
31+
"//executorch/extension/llm/export:export_lib",
32+
"//executorch/extension/pybindings:aten_lib",
33+
],
34+
)
35+
1836
python_binary(
1937
name = "llama",
2038
srcs = ["llama.py"],

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -819,7 +819,7 @@ def post_process():
819819
logging.info(f"Results[{idx}]:\n{output}")
820820

821821

822-
def main():
822+
def _build_parser():
823823
parser = setup_common_args_and_variables()
824824
parser.add_argument(
825825
"-a",
@@ -944,7 +944,13 @@ def main():
944944
type=str,
945945
)
946946

947-
args = parser.parse_args()
947+
return parser
948+
949+
950+
def main(args) -> None:
951+
parser = _build_parser()
952+
953+
args = parser.parse_args(args)
948954
if args.compile_only and args.pre_gen_pte:
949955
exit("Cannot set both compile_only and pre_gen_pte as true")
950956

@@ -1035,4 +1041,4 @@ def main():
10351041

10361042
# flake8: noqa: C901
10371043
if __name__ == "__main__":
1038-
main()
1044+
main(sys.argv[1:])

0 commit comments

Comments
 (0)