
Commit e2fab02

Use llm_config instead of args in export_llama functions

Pull Request resolved: #11162 (imported using ghimport)
Differential Revision: [D75484927](https://our.internmc.facebook.com/intern/diff/D75484927/)
ghstack-source-id: 288486592
1 parent: 84e1fda

File tree: 11 files changed (+476, -391 lines)

backends/arm/test/models/test_llama.py
Lines changed: 5 additions & 1 deletion

@@ -22,6 +22,9 @@
     TosaPipelineMI,
 )
 
+from executorch.examples.models.llama.config.llm_config_utils import (
+    convert_args_to_llm_config,
+)
 from executorch.examples.models.llama.export_llama_lib import (
     build_args_parser,
     get_llama_model,
@@ -89,8 +92,9 @@ def prepare_model(self):
         ]
         parser = build_args_parser()
         args = parser.parse_args(args)
+        llm_config = convert_args_to_llm_config(args)
 
-        llama_model, llama_inputs, llama_meta = get_llama_model(args)
+        llama_model, llama_inputs, llama_meta = get_llama_model(llm_config)
 
         return llama_model, llama_inputs, llama_meta
 

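The pattern above generalizes to any call site that previously handed the argparse namespace straight to get_llama_model: parse the CLI arguments, convert them, then pass the config. A minimal sketch of that flow, assuming an executorch install; the argument values are placeholders rather than a working model setup:

    from executorch.examples.models.llama.config.llm_config_utils import (
        convert_args_to_llm_config,
    )
    from executorch.examples.models.llama.export_llama_lib import (
        build_args_parser,
        get_llama_model,
    )

    # Placeholder CLI arguments; a real run also needs checkpoint/params paths.
    args = build_args_parser().parse_args(["--model", "llama2", "--use_kv_cache"])

    # Bridge the flat argparse namespace into the structured LlmConfig, then
    # pass the config (not the raw args) to the export helper.
    llm_config = convert_args_to_llm_config(args)
    llama_model, llama_inputs, llama_meta = get_llama_model(llm_config)
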
examples/apple/mps/scripts/mps_example.py
Lines changed: 13 additions & 16 deletions

@@ -20,6 +20,7 @@
     serialize_from_bundled_program_to_flatbuffer,
 )
 
+from executorch.examples.models.llama.config.llm_config import LlmConfig
 from executorch.exir import (
     EdgeCompileConfig,
     EdgeProgramManager,
@@ -131,28 +132,24 @@ def parse_args():
     return args
 
 
-def get_model_config(args):
-    model_config = {}
-    model_config["module_name"] = MODEL_NAME_TO_MODEL[args.model_name][0]
-    model_config["model_class_name"] = MODEL_NAME_TO_MODEL[args.model_name][1]
-
-    if args.model_name == "llama2":
-        if args.checkpoint:
-            model_config["checkpoint"] = args.checkpoint
-        if args.params:
-            model_config["params"] = args.params
-        model_config["use_kv_cache"] = True
-    return model_config
-
-
 if __name__ == "__main__":
     args = parse_args()
 
     if args.model_name not in MODEL_NAME_TO_MODEL:
         raise RuntimeError(f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}.")
 
-    model_config = get_model_config(args)
-    model, example_inputs, _, _ = EagerModelFactory.create_model(**model_config)
+    llm_config = LlmConfig()
+    if args.model_name == "llama2":
+        if args.checkpoint:
+            llm_config.base.checkpoint = args.checkpoint
+        if args.params:
+            llm_config.base.params = args.params
+        llm_config.model.use_kv_cache = True
+    model, example_inputs, _, _ = EagerModelFactory.create_model(
+        module_name=MODEL_NAME_TO_MODEL[args.model_name][0],
+        model_class_name=MODEL_NAME_TO_MODEL[args.model_name][1],
+        llm_config=llm_config,
+    )
 
     model = model.eval()

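When there is no export_llama argument parser to convert from, as in this MPS script, the config is constructed directly and its nested fields are set by attribute. A minimal sketch of that pattern; the paths are placeholder assumptions:

    from executorch.examples.models.llama.config.llm_config import LlmConfig

    # Start from the defaults and override only the fields this script cares
    # about; everything left unset keeps its LlmConfig default.
    llm_config = LlmConfig()
    llm_config.base.checkpoint = "/tmp/llama2/consolidated.00.pth"  # placeholder
    llm_config.base.params = "/tmp/llama2/params.json"  # placeholder
    llm_config.model.use_kv_cache = True
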
examples/models/llama/TARGETS
Lines changed: 2 additions & 0 deletions

@@ -67,6 +67,7 @@ runtime.python_library(
         "//caffe2:torch",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models/llama:llama_transformer",
+        "//executorch/examples/models/llama/config:llm_config",
         "//executorch/examples/models:checkpoint",
     ],
 )
@@ -266,6 +267,7 @@ runtime.python_library(
         ":export_library",
         "//executorch/examples/models/llama/config:llm_config",
         "fbsource//third-party/pypi/hydra-core:hydra-core",
+        "fbsource//third-party/pypi/omegaconf:omegaconf",
     ],
 )
 

examples/models/llama/config/llm_config_utils.py
Lines changed: 112 additions & 61 deletions

@@ -26,92 +26,143 @@ def convert_args_to_llm_config(args: argparse.Namespace) -> LlmConfig:
     llm_config = LlmConfig()
 
     # BaseConfig
-    llm_config.base.model_class = ModelType(args.model)
-    llm_config.base.params = args.params
-    llm_config.base.checkpoint = args.checkpoint
-    llm_config.base.checkpoint_dir = args.checkpoint_dir
-    llm_config.base.tokenizer_path = args.tokenizer_path
-    llm_config.base.metadata = args.metadata
-    llm_config.base.use_lora = bool(args.use_lora)
-    llm_config.base.fairseq2 = args.fairseq2
+    if hasattr(args, "model"):
+        llm_config.base.model_class = ModelType(args.model)
+    if hasattr(args, "params"):
+        llm_config.base.params = args.params
+    if hasattr(args, "checkpoint"):
+        llm_config.base.checkpoint = args.checkpoint
+    if hasattr(args, "checkpoint_dir"):
+        llm_config.base.checkpoint_dir = args.checkpoint_dir
+    if hasattr(args, "tokenizer_path"):
+        llm_config.base.tokenizer_path = args.tokenizer_path
+    if hasattr(args, "metadata"):
+        llm_config.base.metadata = args.metadata
+    if hasattr(args, "use_lora"):
+        llm_config.base.use_lora = args.use_lora
+    if hasattr(args, "fairseq2"):
+        llm_config.base.fairseq2 = args.fairseq2
 
     # PreqMode settings
-    if args.preq_mode:
+    if hasattr(args, "preq_mode") and args.preq_mode:
         llm_config.base.preq_mode = PreqMode(args.preq_mode)
-        llm_config.base.preq_group_size = args.preq_group_size
-        llm_config.base.preq_embedding_quantize = args.preq_embedding_quantize
+    if hasattr(args, "preq_group_size"):
+        llm_config.base.preq_group_size = args.preq_group_size
+    if hasattr(args, "preq_embedding_quantize"):
+        llm_config.base.preq_embedding_quantize = args.preq_embedding_quantize
 
     # ModelConfig
-    llm_config.model.dtype_override = DtypeOverride(args.dtype_override)
-    llm_config.model.enable_dynamic_shape = args.enable_dynamic_shape
-    llm_config.model.use_shared_embedding = args.use_shared_embedding
-    llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache
-    llm_config.model.expand_rope_table = args.expand_rope_table
-    llm_config.model.use_attention_sink = args.use_attention_sink
-    llm_config.model.output_prune_map = args.output_prune_map
-    llm_config.model.input_prune_map = args.input_prune_map
-    llm_config.model.use_kv_cache = args.use_kv_cache
-    llm_config.model.quantize_kv_cache = args.quantize_kv_cache
-    llm_config.model.local_global_attention = args.local_global_attention
+    if hasattr(args, "dtype_override"):
+        llm_config.model.dtype_override = DtypeOverride(args.dtype_override)
+    if hasattr(args, "enable_dynamic_shape"):
+        llm_config.model.enable_dynamic_shape = args.enable_dynamic_shape
+    if hasattr(args, "use_shared_embedding"):
+        llm_config.model.use_shared_embedding = args.use_shared_embedding
+    if hasattr(args, "use_sdpa_with_kv_cache"):
+        llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache
+    if hasattr(args, "expand_rope_table"):
+        llm_config.model.expand_rope_table = args.expand_rope_table
+    if hasattr(args, "use_attention_sink"):
+        llm_config.model.use_attention_sink = args.use_attention_sink
+    if hasattr(args, "output_prune_map"):
+        llm_config.model.output_prune_map = args.output_prune_map
+    if hasattr(args, "input_prune_map"):
+        llm_config.model.input_prune_map = args.input_prune_map
+    if hasattr(args, "use_kv_cache"):
+        llm_config.model.use_kv_cache = args.use_kv_cache
+    if hasattr(args, "quantize_kv_cache"):
+        llm_config.model.quantize_kv_cache = args.quantize_kv_cache
+    if hasattr(args, "local_global_attention"):
+        llm_config.model.local_global_attention = args.local_global_attention
 
     # ExportConfig
-    llm_config.export.max_seq_length = args.max_seq_length
-    llm_config.export.max_context_length = args.max_context_length
-    llm_config.export.output_dir = args.output_dir
-    llm_config.export.output_name = args.output_name
-    llm_config.export.so_library = args.so_library
-    llm_config.export.export_only = args.export_only
+    if hasattr(args, "max_seq_length"):
+        llm_config.export.max_seq_length = args.max_seq_length
+    if hasattr(args, "max_context_length"):
+        llm_config.export.max_context_length = args.max_context_length
+    if hasattr(args, "output_dir"):
+        llm_config.export.output_dir = args.output_dir
+    if hasattr(args, "output_name"):
+        llm_config.export.output_name = args.output_name
+    if hasattr(args, "so_library"):
+        llm_config.export.so_library = args.so_library
+    if hasattr(args, "export_only"):
+        llm_config.export.export_only = args.export_only
 
     # QuantizationConfig
-    llm_config.quantization.qmode = args.quantization_mode
-    llm_config.quantization.embedding_quantize = args.embedding_quantize
-    if args.pt2e_quantize:
+    if hasattr(args, "quantization_mode"):
+        llm_config.quantization.qmode = args.quantization_mode
+    if hasattr(args, "embedding_quantize"):
+        llm_config.quantization.embedding_quantize = args.embedding_quantize
+    if hasattr(args, "pt2e_quantize") and args.pt2e_quantize:
         llm_config.quantization.pt2e_quantize = Pt2eQuantize(args.pt2e_quantize)
-    llm_config.quantization.group_size = args.group_size
-    if args.use_spin_quant:
+    if hasattr(args, "group_size"):
+        llm_config.quantization.group_size = args.group_size
+    if hasattr(args, "use_spin_quant") and args.use_spin_quant:
         llm_config.quantization.use_spin_quant = SpinQuant(args.use_spin_quant)
-    llm_config.quantization.use_qat = args.use_qat
-    llm_config.quantization.calibration_tasks = args.calibration_tasks
-    llm_config.quantization.calibration_limit = args.calibration_limit
-    llm_config.quantization.calibration_seq_length = args.calibration_seq_length
-    llm_config.quantization.calibration_data = args.calibration_data
-
-    # BackendConfig
-    # XNNPack
-    llm_config.backend.xnnpack.enabled = args.xnnpack
-    llm_config.backend.xnnpack.extended_ops = args.xnnpack_extended_ops
+    if hasattr(args, "use_qat"):
+        llm_config.quantization.use_qat = args.use_qat
+    if hasattr(args, "calibration_tasks"):
+        llm_config.quantization.calibration_tasks = args.calibration_tasks
+    if hasattr(args, "calibration_limit"):
+        llm_config.quantization.calibration_limit = args.calibration_limit
+    if hasattr(args, "calibration_seq_length"):
+        llm_config.quantization.calibration_seq_length = args.calibration_seq_length
+    if hasattr(args, "calibration_data"):
+        llm_config.quantization.calibration_data = args.calibration_data
+
+    # BackendConfig - XNNPack
+    if hasattr(args, "xnnpack"):
+        llm_config.backend.xnnpack.enabled = args.xnnpack
+    if hasattr(args, "xnnpack_extended_ops"):
+        llm_config.backend.xnnpack.extended_ops = args.xnnpack_extended_ops
 
     # CoreML
-    llm_config.backend.coreml.enabled = args.coreml
+    if hasattr(args, "coreml"):
+        llm_config.backend.coreml.enabled = args.coreml
     llm_config.backend.coreml.enable_state = getattr(args, "coreml_enable_state", False)
     llm_config.backend.coreml.preserve_sdpa = getattr(
         args, "coreml_preserve_sdpa", False
     )
-    if args.coreml_quantize:
+    if hasattr(args, "coreml_quantize") and args.coreml_quantize:
         llm_config.backend.coreml.quantize = CoreMLQuantize(args.coreml_quantize)
-    llm_config.backend.coreml.ios = args.coreml_ios
-    llm_config.backend.coreml.compute_units = CoreMLComputeUnit(
-        args.coreml_compute_units
-    )
+    if hasattr(args, "coreml_ios"):
+        llm_config.backend.coreml.ios = args.coreml_ios
+    if hasattr(args, "coreml_compute_units"):
+        llm_config.backend.coreml.compute_units = CoreMLComputeUnit(
+            args.coreml_compute_units
+        )
 
     # Vulkan
-    llm_config.backend.vulkan.enabled = args.vulkan
+    if hasattr(args, "vulkan"):
+        llm_config.backend.vulkan.enabled = args.vulkan
 
     # QNN
-    llm_config.backend.qnn.enabled = args.qnn
-    llm_config.backend.qnn.use_sha = args.use_qnn_sha
-    llm_config.backend.qnn.soc_model = args.soc_model
-    llm_config.backend.qnn.optimized_rotation_path = args.optimized_rotation_path
-    llm_config.backend.qnn.num_sharding = args.num_sharding
+    if hasattr(args, "qnn"):
+        llm_config.backend.qnn.enabled = args.qnn
+    if hasattr(args, "use_qnn_sha"):
+        llm_config.backend.qnn.use_sha = args.use_qnn_sha
+    if hasattr(args, "soc_model"):
+        llm_config.backend.qnn.soc_model = args.soc_model
+    if hasattr(args, "optimized_rotation_path"):
+        llm_config.backend.qnn.optimized_rotation_path = args.optimized_rotation_path
+    if hasattr(args, "num_sharding"):
+        llm_config.backend.qnn.num_sharding = args.num_sharding
 
     # MPS
-    llm_config.backend.mps.enabled = args.mps
+    if hasattr(args, "mps"):
+        llm_config.backend.mps.enabled = args.mps
 
     # DebugConfig
-    llm_config.debug.profile_memory = args.profile_memory
-    llm_config.debug.profile_path = args.profile_path
-    llm_config.debug.generate_etrecord = args.generate_etrecord
-    llm_config.debug.generate_full_logits = args.generate_full_logits
-    llm_config.debug.verbose = args.verbose
+    if hasattr(args, "profile_memory"):
+        llm_config.debug.profile_memory = args.profile_memory
+    if hasattr(args, "profile_path"):
+        llm_config.debug.profile_path = args.profile_path
+    if hasattr(args, "generate_etrecord"):
+        llm_config.debug.generate_etrecord = args.generate_etrecord
+    if hasattr(args, "generate_full_logits"):
+        llm_config.debug.generate_full_logits = args.generate_full_logits
+    if hasattr(args, "verbose"):
+        llm_config.debug.verbose = args.verbose
 
     return llm_config

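The hasattr guards are what allow an argparse namespace that defines only a subset of the export_llama flags to pass through the converter without an AttributeError: absent flags simply leave the LlmConfig defaults in place. A standalone sketch of the same guard pattern, using a hypothetical parser and field list:

    import argparse

    # Hypothetical partial parser: defines only two of the flags the
    # converter knows about.
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default=None)
    parser.add_argument("--use_kv_cache", action="store_true")
    args = parser.parse_args(["--use_kv_cache"])

    # Guarded copy: only attributes the parser actually defined are
    # transferred; the rest are skipped instead of raising AttributeError.
    copied = {}
    for name in ("checkpoint", "params", "use_kv_cache"):
        if hasattr(args, name):
            copied[name] = getattr(args, name)

    print(copied)  # {'checkpoint': None, 'use_kv_cache': True}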