Skip to content

Commit 47d50e9

Browse files
limintangfacebook-github-bot
authored andcommitted
Refactor and add Llama Python library build
Summary: As title. To use static llama export outside QC dir. Differential Revision: D68937637
1 parent 42f744a commit 47d50e9

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

examples/qualcomm/oss_scripts/llama/TARGETS

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,24 @@ python_library(
1515
],
1616
)
1717

18+
python_library(
19+
name = "llama_lib",
20+
srcs = ["llama.py"],
21+
deps = [
22+
"//caffe2:torch",
23+
"//executorch/backends/qualcomm/partition:partition",
24+
"//executorch/backends/qualcomm/quantizer:quantizer",
25+
"//executorch/devtools:lib",
26+
"//executorch/examples/models:models",
27+
"//executorch/examples/qualcomm/oss_scripts/llama:static_llama",
28+
"//executorch/examples/qualcomm:utils",
29+
"//executorch/extension/export_util:export_util",
30+
"//executorch/extension/llm/custom_ops:model_sharding_py",
31+
"//executorch/extension/llm/export:export_lib",
32+
"//executorch/extension/pybindings:aten_lib",
33+
],
34+
)
35+
1836
python_binary(
1937
name = "llama",
2038
srcs = ["llama.py"],

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -819,7 +819,7 @@ def post_process():
819819
logging.info(f"Results[{idx}]:\n{output}")
820820

821821

822-
def main():
822+
def _build_parser():
823823
parser = setup_common_args_and_variables()
824824
parser.add_argument(
825825
"-a",
@@ -944,7 +944,12 @@ def main():
944944
type=str,
945945
)
946946

947-
args = parser.parse_args()
947+
return parser
948+
949+
def main(args) -> None:
950+
parser = _build_parser()
951+
952+
args = parser.parse_args(args)
948953
if args.compile_only and args.pre_gen_pte:
949954
exit("Cannot set both compile_only and pre_gen_pte as true")
950955

@@ -1035,4 +1040,4 @@ def main():
10351040

10361041
# flake8: noqa: C901
10371042
if __name__ == "__main__":
1038-
main()
1043+
main(sys.argv[1:])

0 commit comments

Comments
 (0)