Skip to content

Commit 29faa2e

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Preserve modelname (#3122)
Summary: Pull Request resolved: #3122 Reviewed By: mikekgfb Differential Revision: D56212361 fbshipit-source-id: 877f2d3d8b2c078e21b0ababdfbc4e447cd86374
1 parent f2e660b commit 29faa2e

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

examples/models/llama2/builder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def to_torch_dtype(self) -> torch.dtype:
6262

6363
def load_llama_model(
6464
*,
65+
modelname: str = "llama2",
6566
checkpoint: Optional[str] = None,
6667
checkpoint_dir: Optional[str] = None,
6768
params_path: str,
@@ -114,6 +115,7 @@ def load_llama_model(
114115

115116
return LlamaEdgeManager(
116117
model=model,
118+
modelname=modelname,
117119
weight_type=weight_type,
118120
dtype=dtype,
119121
use_kv_cache=use_kv_cache,
@@ -131,6 +133,7 @@ class LlamaEdgeManager:
131133
def __init__(
132134
self,
133135
model,
136+
modelname,
134137
weight_type,
135138
dtype,
136139
use_kv_cache,
@@ -139,6 +142,7 @@ def __init__(
139142
verbose: bool = False,
140143
):
141144
self.model = model
145+
self.modelname = modelname
142146
self.weight_type = weight_type
143147
self.dtype = dtype
144148
self.example_inputs = example_inputs

examples/models/llama2/export_llama_lib.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,6 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
485485
)
486486
params_path = canonical_path(args.params)
487487
output_dir_path = canonical_path(args.output_dir, dir=True)
488-
modelname = "llama2"
489488
weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
490489

491490
# dtype override
@@ -552,6 +551,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
552551

553552
return (
554553
load_llama_model(
554+
modelname=modelname,
555555
checkpoint=checkpoint_path,
556556
checkpoint_dir=checkpoint_dir,
557557
params_path=params_path,
@@ -599,6 +599,8 @@ def _export_llama(modelname, args) -> str: # noqa: C901
599599
modelname, args
600600
).export_to_edge(quantizers)
601601

602+
modelname = builder_exported_to_edge.modelname
603+
602604
# to_backend
603605
partitioners = []
604606
if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:

0 commit comments

Comments
 (0)