
Commit 14df400

refactor: remove backward compatibility and simplify Llama2Model configuration
ghstack-source-id: 6017106
Pull Request resolved: pytorch/executorch#11169
1 parent: 8d55ee4

File tree

examples/models/llama/export_llama_lib.py
examples/models/llama/model.py

2 files changed: +16 -33 lines

examples/models/llama/export_llama_lib.py

Lines changed: 0 additions & 5 deletions
@@ -1226,11 +1226,6 @@ def _load_llama_model(
         EagerModelFactory.create_model(
             module_name,
             model_class_name,
-            checkpoint=checkpoint,
-            checkpoint_dir=checkpoint_dir,
-            params=params_path,
-            fairseq2=weight_type == WeightType.FAIRSEQ2,
-            dtype=torch_dtype,
             llm_config=llm_config,
         )
     )

examples/models/llama/model.py

Lines changed: 16 additions & 28 deletions
@@ -47,30 +47,18 @@ def __init__(self, **kwargs):
         # Params file.
         params_path = kwargs.get("params", None)
 
-        self.llm_config = kwargs.get("llm_config", None)
+        self.llm_config = kwargs.get("llm_config")
+        assert self.llm_config is not None, "llm_config must be provided"
 
-        # Set all parameters from llm_config if available, otherwise use kwargs as fallback
-        if self.llm_config:
-            self.use_kv_cache = self.llm_config.model.use_kv_cache
-            self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
-            self.generate_full_logits = self.llm_config.debug.generate_full_logits
-            self.enable_dynamic_shape = self.llm_config.model.enable_dynamic_shape
-            self.input_prune_map_path = self.llm_config.model.input_prune_map
-            self.output_prune_map_path = self.llm_config.model.output_prune_map
-            self.max_seq_len = self.llm_config.export.max_seq_length
-            self.max_context_len = self.llm_config.export.max_context_length
-            self.verbose = self.llm_config.debug.verbose
-        else:
-            # Fallback to kwargs for backward compatibility
-            self.use_kv_cache = kwargs.get("use_kv_cache", False)
-            self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
-            self.generate_full_logits = kwargs.get("generate_full_logits", False)
-            self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
-            self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
-            self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
-            self.max_seq_len = kwargs.get("max_seq_len", 128)
-            self.max_context_len = kwargs.get("max_context_len", 128)
-            self.verbose = kwargs.get("verbose", False)
+        self.use_kv_cache = self.llm_config.model.use_kv_cache
+        self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
+        self.generate_full_logits = self.llm_config.debug.generate_full_logits
+        self.enable_dynamic_shape = self.llm_config.model.enable_dynamic_shape
+        self.input_prune_map_path = self.llm_config.model.input_prune_map
+        self.output_prune_map_path = self.llm_config.model.output_prune_map
+        self.max_seq_len = self.llm_config.export.max_seq_length
+        self.max_context_len = self.llm_config.export.max_context_length
+        self.verbose = self.llm_config.debug.verbose
 
         assert (
             self.max_context_len >= self.max_seq_len
@@ -173,7 +161,7 @@ def __init__(self, **kwargs):
 
         if model_args.use_scaled_rope:
             # Older models don't have use_scaled_rope configuration
-            model_name = str(self.llm_config.base.model_class) if self.llm_config else "llama3"
+            model_name = str(self.llm_config.base.model_class)
             assert model_name not in ["llama2", "stories110m"]
 
             # Llama3_2 and newer models in ExecuTorch repo should set larger scale factor
@@ -212,7 +200,7 @@ def __init__(self, **kwargs):
             self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime(
                 self.model_
             )
-        elif self.llm_config and self.llm_config.quantization.use_spin_quant:
+        elif self.llm_config.quantization.use_spin_quant:
             print("Using SPIN quantization.")
             self._transform_for_pre_quantization(checkpoint, model_args)
 
@@ -221,7 +209,7 @@ def __init__(self, **kwargs):
             )
 
             sanitize_checkpoint_from_pre_quantization(checkpoint)
-        elif self.llm_config and self.llm_config.quantization.use_qat:
+        elif self.llm_config.quantization.use_qat:
             print("Using QAT quantization.")
             self._transform_for_pre_quantization(checkpoint, model_args)
             if self.llm_config.base.use_lora:
@@ -243,7 +231,7 @@ def __init__(self, **kwargs):
 
             sanitize_checkpoint_from_pre_quantization(checkpoint)
 
-        if self.llm_config and self.llm_config.model.use_attention_sink:
+        if self.llm_config.model.use_attention_sink:
             from .source_transformation.attention_sink import enable_attention_sink
 
             attention_sink_params = self.llm_config.model.use_attention_sink.split(",")
@@ -343,7 +331,7 @@ def get_example_inputs_kvcache_sdpa(self):
        )
 
    def _transform_for_pre_quantization(self, checkpoint, model_args):
-        assert self.llm_config and self.llm_config.base.preq_mode, "preq_mode must be specified"
+        assert self.llm_config.base.preq_mode, "preq_mode must be specified"
        assert self.llm_config.base.preq_mode in [
            "8da4w",
            "8da4w_output_8da8w",
