Skip to content

Commit e603f4f

Browse files
cccclaifacebook-github-bot
authored andcommitted
allow dummy_llama2 script to take real checkpoint/params (#2588)
Summary: Pull Request resolved: #2588 In this way we can pass the checkpoint/params to the script ``` python3.10 dummy_llama2.py -b artifact/ -m SM8650 --checkpoint /home/chenlai/qualcomm/meta-llama-mldemos-examples/models/stories110M/stories110M.pt --params /home/chenlai/qualcomm/meta-llama-mldemos-examples/models/stories110M/params.json ``` Reviewed By: kirklandsign Differential Revision: D55172497 fbshipit-source-id: c34bba1d60d911380b10a2cf39b62892628b0a1b
1 parent c8f2d8d commit e603f4f

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

examples/qualcomm/scripts/dummy_llama2.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,34 @@ def create_device_inputs(example_inputs, use_kv_cache):
7272
default="8a8w",
7373
)
7474

75+
parser.add_argument(
76+
"--checkpoint",
77+
help="Pass llama2 checkpoint.",
78+
default=False,
79+
)
80+
81+
parser.add_argument(
82+
"--params",
83+
help="Pass llama2 params json file.",
84+
default=False,
85+
)
86+
7587
args = parser.parse_args()
7688

7789
# ensure the working directory exist.
7890
os.makedirs(args.artifact, exist_ok=True)
7991

80-
instance = Llama2Model(use_kv_cache=args.use_kv_cache)
92+
if args.params and args.checkpoint:
93+
instance = Llama2Model(
94+
use_kv_cache=args.use_kv_cache,
95+
checkpoint=args.checkpoint,
96+
params=args.params,
97+
)
98+
else:
99+
instance = Llama2Model(
100+
use_kv_cache=args.use_kv_cache,
101+
)
102+
81103
inputs, input_list = create_device_inputs(
82104
instance.get_example_inputs(), args.use_kv_cache
83105
)

0 commit comments

Comments
 (0)