
Commit 79d8d24

Update on "refactor: Use llm_config instead of args in export_llama functions"
[ghstack-poisoned]
2 parents b9bfb11 + d6b20c1

File tree

1 file changed: +18 -8 lines changed


examples/models/llama/export_llama_lib.py

Lines changed: 18 additions & 8 deletions
@@ -651,11 +651,19 @@ def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
     Returns a LLMEdgeManager prior to calling export_to_edge with quantizers
     """
     # load model from checkpoint and params.json
-    checkpoint_path = canonical_path(llm_config.base.checkpoint) if llm_config.base.checkpoint else None
+    checkpoint_path = (
+        canonical_path(llm_config.base.checkpoint)
+        if llm_config.base.checkpoint
+        else None
+    )
     checkpoint_dir = (
-        canonical_path(llm_config.base.checkpoint_dir) if llm_config.base.checkpoint_dir else None
+        canonical_path(llm_config.base.checkpoint_dir)
+        if llm_config.base.checkpoint_dir
+        else None
+    )
+    params_path = (
+        canonical_path(llm_config.base.params) if llm_config.base.params else None
     )
-    params_path = canonical_path(llm_config.base.params) if llm_config.base.params else None
     output_dir_path = canonical_path(llm_config.export.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if llm_config.base.fairseq2 else WeightType.LLAMA
 
@@ -744,7 +752,7 @@ def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
             preq_mode=llm_config.base.preq_mode,
             preq_group_size=llm_config.base.preq_group_size,
             preq_embedding_quantize=llm_config.base.preq_embedding_quantize,
-            local_global_attention=llm_config.model.local_global_attention
+            local_global_attention=llm_config.model.local_global_attention,
         )
     )
 
@@ -804,9 +812,9 @@ def _validate_args(llm_config):
             f"max_context_length {llm_config.export.max_context_length} must be >= max_seq_len {llm_config.export.max_seq_length}. max_context_length impacts kv cache size that is used to remember history, while max_seq_length refers to user prompt length. Please use --max_context_length to specify context length."
         )
     if llm_config.model.enable_dynamic_shape and (
-        llm_config.backend.coreml.enabled or
-        llm_config.backend.mps.enabled or
-        llm_config.backend.qnn.enabled
+        llm_config.backend.coreml.enabled
+        or llm_config.backend.mps.enabled
+        or llm_config.backend.qnn.enabled
     ):
         raise ValueError(
             "Dynamic shape is not supported with coreml, MPS or qnn backends."
@@ -1050,7 +1058,9 @@ def _to_edge_and_lower_llama(  # noqa: C901
 def _export_llama(llm_config, args) -> LLMEdgeManager:  # noqa: C901
     _validate_args(llm_config)
 
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(llm_config)
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
+        llm_config
+    )
 
     additional_passes = []
     if llm_config.base.model_class in TORCHTUNE_DEFINED_MODELS:
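
The hunks above are formatting-only and do not change behavior. As a quick sanity check, here is a small standalone Python sketch; fake_canonical_path is a trivial stand-in invented for this snippet, not the real canonical_path helper. It shows that the wrapped conditional expressions and the reflowed boolean chain evaluate exactly like the single-line forms they replace.

# Minimal sketch, not the ExecuTorch code: fake_canonical_path stands in for
# the real canonical_path helper so the snippet is self-contained.
def fake_canonical_path(path):
    return path

checkpoint = "model.pth"

# Single-line conditional expression (the pre-diff style).
old_style = fake_canonical_path(checkpoint) if checkpoint else None

# Parenthesized, multi-line conditional expression (the post-diff style);
# the parentheses only let the expression span lines, semantics are unchanged.
new_style = (
    fake_canonical_path(checkpoint)
    if checkpoint
    else None
)
assert old_style == new_style

# Likewise, moving `or` from the end of a line to the start of the next line
# inside parentheses changes neither evaluation order nor short-circuiting.
coreml_enabled, mps_enabled, qnn_enabled = False, True, False
trailing_or = (coreml_enabled or
               mps_enabled or
               qnn_enabled)
leading_or = (
    coreml_enabled
    or mps_enabled
    or qnn_enabled
)
assert trailing_or == leading_or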

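For orientation on the config object these functions now read from, the sketch below mocks up the nested attribute layout implied by the dotted accesses in the diff (llm_config.model.enable_dynamic_shape, llm_config.backend.coreml.enabled, and so on). The dataclass names, defaults, and the check_dynamic_shape wrapper are invented for illustration; only the field paths and the error message come from the code above, and the real LlmConfig is defined elsewhere in the ExecuTorch repo.

from dataclasses import dataclass, field

# Hypothetical stand-ins for illustration only.
@dataclass
class _Backend:
    enabled: bool = False

@dataclass
class _BackendConfig:
    coreml: _Backend = field(default_factory=_Backend)
    mps: _Backend = field(default_factory=_Backend)
    qnn: _Backend = field(default_factory=_Backend)

@dataclass
class _ModelConfig:
    enable_dynamic_shape: bool = True

@dataclass
class _FakeLlmConfig:
    model: _ModelConfig = field(default_factory=_ModelConfig)
    backend: _BackendConfig = field(default_factory=_BackendConfig)

def check_dynamic_shape(llm_config):
    # Mirrors the validation shown in the _validate_args hunk above.
    if llm_config.model.enable_dynamic_shape and (
        llm_config.backend.coreml.enabled
        or llm_config.backend.mps.enabled
        or llm_config.backend.qnn.enabled
    ):
        raise ValueError(
            "Dynamic shape is not supported with coreml, MPS or qnn backends."
        )

cfg = _FakeLlmConfig()
cfg.backend.coreml.enabled = True
try:
    check_dynamic_shape(cfg)
except ValueError as e:
    print(e)
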
0 commit comments
