
Commit e32ea56

refactor: Replace self.args with LlmConfig in model.py and export_llama_lib.py
ghstack-source-id: f1f21a3 Pull Request resolved: #11166
1 parent 386bb05

2 files changed (+28, -27 lines)

examples/models/llama/export_llama_lib.py

Lines changed: 1 addition & 1 deletion
@@ -1237,7 +1237,7 @@ def _load_llama_model(
             input_prune_map_path=input_prune_map_path,
             output_prune_map_path=output_prune_map_path,
             dtype=torch_dtype,
-            args=args,
+            llm_config=llm_config,
         )
     )

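The change above only swaps which keyword _load_llama_model forwards into the eager model constructor; everything else at the call site is untouched. A minimal sketch of the pattern, with ModelClass as a placeholder (the actual class instantiated here is not shown in this diff):

    # Sketch only: ModelClass is a stand-in for the eager Llama model being constructed.
    # Before this commit, the raw argparse namespace was forwarded:
    #     model = ModelClass(..., dtype=torch_dtype, args=args)
    # After this commit, the structured config object is forwarded instead, and
    # model.py picks it up via kwargs.get("llm_config", None):
    model = ModelClass(
        input_prune_map_path=input_prune_map_path,
        output_prune_map_path=output_prune_map_path,
        dtype=torch_dtype,
        llm_config=llm_config,
    )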
examples/models/llama/model.py

Lines changed: 27 additions & 26 deletions
@@ -55,7 +55,7 @@ def __init__(self, **kwargs):
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
         self.max_seq_len = kwargs.get("max_seq_len", 128)
         self.max_context_len = kwargs.get("max_context_len", 128)
-        self.args = kwargs.get("args", None)
+        self.llm_config = kwargs.get("llm_config", None)

         assert (
             self.max_context_len >= self.max_seq_len
@@ -158,10 +158,11 @@ def __init__(self, **kwargs):

         if model_args.use_scaled_rope:
             # Older models don't have use_scaled_rope configuration
-            assert self.args.model not in ["llama2", "stories110m"]
+            model_name = str(self.llm_config.base.model_class) if self.llm_config else "llama3"
+            assert model_name not in ["llama2", "stories110m"]

             # Llama3_2 and newer models in ExecuTorch repo should set larger scale factor
-            if self.args.model not in ["llama3", "llama3_1"]:
+            if model_name not in ["llama3", "llama3_1"]:
                 model_args.rope_scale_factor = 32

         if kwargs.get("verbose", False):
@@ -196,7 +197,7 @@ def __init__(self, **kwargs):
             self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime(
                 self.model_
             )
-        elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
+        elif self.llm_config and self.llm_config.quantization.use_spin_quant:
             print("Using SPIN quantization.")
             self._transform_for_pre_quantization(checkpoint, model_args)

@@ -205,19 +206,20 @@ def __init__(self, **kwargs):
             )

             sanitize_checkpoint_from_pre_quantization(checkpoint)
-        elif hasattr(self.args, "use_qat") and self.args.use_qat:
+        elif self.llm_config and self.llm_config.quantization.use_qat:
             print("Using QAT quantization.")
             self._transform_for_pre_quantization(checkpoint, model_args)
-            if hasattr(self.args, "use_lora") and self.args.use_lora:
-                assert model_args.lora_args["rank"] == self.args.use_lora
+            if self.llm_config.base.use_lora:
+                lora_rank = self.llm_config.base.use_lora
+                assert model_args.lora_args["rank"] == lora_rank
                 from .source_transformation.lora import (
                     transform_linear_for_lora_after_quantization,
                 )

                 self.model_ = transform_linear_for_lora_after_quantization(
                     self.model_,
                     checkpoint,
-                    self.args.use_lora,
+                    lora_rank,
                 )

             from .source_transformation.pre_quantization import (
@@ -226,16 +228,16 @@ def __init__(self, **kwargs):

             sanitize_checkpoint_from_pre_quantization(checkpoint)

-        if hasattr(self.args, "use_attention_sink") and self.args.use_attention_sink:
+        if self.llm_config and self.llm_config.model.use_attention_sink:
             from .source_transformation.attention_sink import enable_attention_sink

-            attention_sink_params = self.args.use_attention_sink.split(",")
+            attention_sink_params = self.llm_config.model.use_attention_sink.split(",")
             assert len(attention_sink_params) == 3
             sink_size = int(attention_sink_params[0])
             window_size = int(attention_sink_params[1])
             eviction_batch_size = int(attention_sink_params[2])

-            assert self.args.max_context_length == sink_size + window_size
+            assert self.llm_config.export.max_context_length == sink_size + window_size

             self.model_ = enable_attention_sink(
                 module=self.model_,
@@ -326,20 +328,19 @@ def get_example_inputs_kvcache_sdpa(self):
         )

     def _transform_for_pre_quantization(self, checkpoint, model_args):
-        assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
-        assert self.args.preq_mode in [
+        assert self.llm_config and self.llm_config.base.preq_mode, "preq_mode must be specified"
+        assert self.llm_config.base.preq_mode in [
             "8da4w",
             "8da4w_output_8da8w",
-        ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant."
-        assert hasattr(
-            self.args, "preq_group_size"
-        ), "preq_group_size must be specified"
-        assert hasattr(self.args, "dtype_override"), "dtype_override must be specified"
+        ], f"Quantization mode {self.llm_config.base.preq_mode} is not compatible with SpinQuant."
+        assert self.llm_config.base.preq_group_size, "preq_group_size must be specified"
+        assert self.llm_config.model.dtype_override, "dtype_override must be specified"
+
         from .source_transformation.pre_quantization import (
             transform_linear_for_pre_quantization,
         )

-        assert self.args.preq_group_size == model_args.quantization_args["group_size"]
+        assert self.llm_config.base.preq_group_size == model_args.quantization_args["group_size"]

         mapping = {
             "fp32": torch.float32,
@@ -348,28 +349,28 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
         }

         # Transform the output layer first if needed.
-        if self.args.preq_mode == "8da4w_output_8da8w":
+        if self.llm_config.base.preq_mode == "8da4w_output_8da8w":
             from .source_transformation.pre_quantization import (
                 transform_output_linear_for_pre_quantization,
             )

             self.model_ = transform_output_linear_for_pre_quantization(
                 module=self.model_,
                 checkpoint=checkpoint,
-                dtype=mapping[self.args.dtype_override],
+                dtype=mapping[self.llm_config.model.dtype_override],
             )

         self.model_ = transform_linear_for_pre_quantization(
             self.model_,
             checkpoint,
-            self.args.preq_group_size,
-            mapping[self.args.dtype_override],
+            self.llm_config.base.preq_group_size,
+            mapping[self.llm_config.model.dtype_override],
         )

         embedding_bit_width, embedding_group_size = None, None
-        if hasattr(self.args, "preq_embedding_quantize"):
+        if self.llm_config.base.preq_embedding_quantize:
             embedding_bit_width, embedding_group_size = (
-                self.args.preq_embedding_quantize.split(",")
+                self.llm_config.base.preq_embedding_quantize.split(",")
             )
             from .source_transformation.pre_quantization import (
                 transform_embedding_for_pre_quantization,
@@ -387,7 +388,7 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
             self.model_ = transform_embedding_for_pre_quantization(
                 self.model_,
                 checkpoint,
-                mapping[self.args.dtype_override],
+                mapping[self.llm_config.model.dtype_override],
                 int(embedding_bit_width),
                 embedding_group_size,
             )
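The diff above shows only how the config is read, not how it is defined. Inferring from the attribute accesses in this file, the relevant slice of LlmConfig groups fields roughly as sketched below; the field names are grounded in the accesses above, but the types and defaults are assumptions, not taken from this commit:

    from dataclasses import dataclass, field
    from typing import Optional

    # Hypothetical reconstruction of the parts of LlmConfig touched in model.py.
    @dataclass
    class BaseConfig:
        model_class: str = "llama3"                    # replaces args.model; may be an enum
                                                       # upstream (the diff wraps it in str())
        use_lora: int = 0                              # LoRA rank; replaces args.use_lora
        preq_mode: Optional[str] = None                # replaces args.preq_mode
        preq_group_size: int = 32                      # replaces args.preq_group_size
        preq_embedding_quantize: Optional[str] = None  # "bitwidth,group_size" string

    @dataclass
    class ModelConfig:
        dtype_override: str = "fp32"                   # key into the dtype mapping above
        use_attention_sink: Optional[str] = None       # "sink_size,window_size,batch_size"

    @dataclass
    class QuantizationConfig:
        use_spin_quant: Optional[str] = None           # replaces args.use_spin_quant
        use_qat: bool = False                          # replaces args.use_qat

    @dataclass
    class ExportConfig:
        max_context_length: int = 128                  # replaces args.max_context_length

    @dataclass
    class LlmConfig:
        base: BaseConfig = field(default_factory=BaseConfig)
        model: ModelConfig = field(default_factory=ModelConfig)
        quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
        export: ExportConfig = field(default_factory=ExportConfig)

Note that the refactor also changes the guard style: the old code probed the argparse namespace with hasattr, while the new code checks the truthiness of self.llm_config and of individual fields, falling back (e.g. to "llama3" for the model name) when no config is passed.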
