refactor: Replace self.args with LlmConfig in model.py and export_llama_lib.py #11166

Closed
wants to merge 1 commit into from
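
The diff below replaces the untyped self.args namespace (argparse-style options threaded in from the export CLI) with the structured llm_config object when building the llama model. For orientation, the following is a minimal mock of the nested fields the updated code reads. It is a sketch inferred from this diff alone: the types and defaults are assumptions, and the real LlmConfig in the ExecuTorch repo defines these and many more fields.

# Sketch only: nested config fields referenced by the updated model.py below.
# Types and defaults are assumptions inferred from this diff, not the real definition.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class BaseConfig:
    model_class: str = "llama3"                    # e.g. "llama2", "llama3_1", "stories110m"
    use_lora: int = 0                              # LoRA rank; 0 (falsy) means LoRA disabled
    preq_mode: Optional[str] = None                # "8da4w" or "8da4w_output_8da8w"
    preq_group_size: Optional[int] = None
    preq_embedding_quantize: Optional[str] = None  # "bitwidth,group_size"


@dataclass
class ModelConfig:
    dtype_override: Optional[str] = None           # "fp32", "fp16", or "bf16"
    use_attention_sink: Optional[str] = None       # "sink_size,window_size,eviction_batch_size"


@dataclass
class QuantizationConfig:
    use_spin_quant: Optional[str] = None
    use_qat: bool = False


@dataclass
class ExportConfig:
    max_context_length: int = 128


@dataclass
class LlmConfig:
    base: BaseConfig = field(default_factory=BaseConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
    export: ExportConfig = field(default_factory=ExportConfig)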
2 changes: 1 addition & 1 deletion examples/models/llama/export_llama_lib.py
@@ -1237,7 +1237,7 @@ def _load_llama_model(
             input_prune_map_path=input_prune_map_path,
             output_prune_map_path=output_prune_map_path,
             dtype=torch_dtype,
-            args=args,
+            llm_config=llm_config,
         )
     )

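The same substitution is applied throughout model.py below, and it follows one pattern: hasattr() probes against the flat args namespace become a single None-check on llm_config followed by direct reads of nested, typed fields. A schematic before/after using the QAT flag from this diff (the SimpleNamespace stands in for the old argparse args; LlmConfig and QuantizationConfig refer to the mock sketched above):

from types import SimpleNamespace

# Before: flat, untyped namespace; flags had to be probed with hasattr().
args = SimpleNamespace(use_qat=True)
if hasattr(args, "use_qat") and args.use_qat:
    print("QAT path (old-style args)")

# After: None-check the config object once, then read the nested field directly.
llm_config = LlmConfig(quantization=QuantizationConfig(use_qat=True))
if llm_config and llm_config.quantization.use_qat:
    print("QAT path (LlmConfig)")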
53 changes: 27 additions & 26 deletions examples/models/llama/model.py
@@ -55,7 +55,7 @@ def __init__(self, **kwargs):
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
         self.max_seq_len = kwargs.get("max_seq_len", 128)
         self.max_context_len = kwargs.get("max_context_len", 128)
-        self.args = kwargs.get("args", None)
+        self.llm_config = kwargs.get("llm_config", None)
 
         assert (
             self.max_context_len >= self.max_seq_len
@@ -158,10 +158,11 @@ def __init__(self, **kwargs):
 
         if model_args.use_scaled_rope:
             # Older models don't have use_scaled_rope configuration
-            assert self.args.model not in ["llama2", "stories110m"]
+            model_name = str(self.llm_config.base.model_class) if self.llm_config else "llama3"
+            assert model_name not in ["llama2", "stories110m"]
 
             # Llama3_2 and newer models in ExecuTorch repo should set larger scale factor
-            if self.args.model not in ["llama3", "llama3_1"]:
+            if model_name not in ["llama3", "llama3_1"]:
                 model_args.rope_scale_factor = 32
 
         if kwargs.get("verbose", False):
@@ -196,7 +197,7 @@ def __init__(self, **kwargs):
             self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime(
                 self.model_
             )
-        elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
+        elif self.llm_config and self.llm_config.quantization.use_spin_quant:
             print("Using SPIN quantization.")
             self._transform_for_pre_quantization(checkpoint, model_args)
 
@@ -205,19 +206,20 @@ def __init__(self, **kwargs):
             )
 
             sanitize_checkpoint_from_pre_quantization(checkpoint)
-        elif hasattr(self.args, "use_qat") and self.args.use_qat:
+        elif self.llm_config and self.llm_config.quantization.use_qat:
             print("Using QAT quantization.")
             self._transform_for_pre_quantization(checkpoint, model_args)
-            if hasattr(self.args, "use_lora") and self.args.use_lora:
-                assert model_args.lora_args["rank"] == self.args.use_lora
+            if self.llm_config.base.use_lora:
+                lora_rank = self.llm_config.base.use_lora
+                assert model_args.lora_args["rank"] == lora_rank
                 from .source_transformation.lora import (
                     transform_linear_for_lora_after_quantization,
                 )
 
                 self.model_ = transform_linear_for_lora_after_quantization(
                     self.model_,
                     checkpoint,
-                    self.args.use_lora,
+                    lora_rank,
                 )
 
             from .source_transformation.pre_quantization import (
@@ -226,16 +228,16 @@ def __init__(self, **kwargs):
 
             sanitize_checkpoint_from_pre_quantization(checkpoint)
 
-        if hasattr(self.args, "use_attention_sink") and self.args.use_attention_sink:
+        if self.llm_config and self.llm_config.model.use_attention_sink:
             from .source_transformation.attention_sink import enable_attention_sink
 
-            attention_sink_params = self.args.use_attention_sink.split(",")
+            attention_sink_params = self.llm_config.model.use_attention_sink.split(",")
             assert len(attention_sink_params) == 3
             sink_size = int(attention_sink_params[0])
             window_size = int(attention_sink_params[1])
             eviction_batch_size = int(attention_sink_params[2])
 
-            assert self.args.max_context_length == sink_size + window_size
+            assert self.llm_config.export.max_context_length == sink_size + window_size
 
             self.model_ = enable_attention_sink(
                 module=self.model_,
@@ -326,20 +328,19 @@ def get_example_inputs_kvcache_sdpa(self):
         )
 
     def _transform_for_pre_quantization(self, checkpoint, model_args):
-        assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
-        assert self.args.preq_mode in [
+        assert self.llm_config and self.llm_config.base.preq_mode, "preq_mode must be specified"
+        assert self.llm_config.base.preq_mode in [
             "8da4w",
             "8da4w_output_8da8w",
-        ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant."
-        assert hasattr(
-            self.args, "preq_group_size"
-        ), "preq_group_size must be specified"
-        assert hasattr(self.args, "dtype_override"), "dtype_override must be specified"
+        ], f"Quantization mode {self.llm_config.base.preq_mode} is not compatible with SpinQuant."
+        assert self.llm_config.base.preq_group_size, "preq_group_size must be specified"
+        assert self.llm_config.model.dtype_override, "dtype_override must be specified"
 
         from .source_transformation.pre_quantization import (
             transform_linear_for_pre_quantization,
         )
 
-        assert self.args.preq_group_size == model_args.quantization_args["group_size"]
+        assert self.llm_config.base.preq_group_size == model_args.quantization_args["group_size"]
 
         mapping = {
             "fp32": torch.float32,
@@ -348,28 +349,28 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
         }
 
         # Transform the output layer first if needed.
-        if self.args.preq_mode == "8da4w_output_8da8w":
+        if self.llm_config.base.preq_mode == "8da4w_output_8da8w":
             from .source_transformation.pre_quantization import (
                 transform_output_linear_for_pre_quantization,
             )
 
             self.model_ = transform_output_linear_for_pre_quantization(
                 module=self.model_,
                 checkpoint=checkpoint,
-                dtype=mapping[self.args.dtype_override],
+                dtype=mapping[self.llm_config.model.dtype_override],
             )
 
         self.model_ = transform_linear_for_pre_quantization(
             self.model_,
             checkpoint,
-            self.args.preq_group_size,
-            mapping[self.args.dtype_override],
+            self.llm_config.base.preq_group_size,
+            mapping[self.llm_config.model.dtype_override],
         )
 
         embedding_bit_width, embedding_group_size = None, None
-        if hasattr(self.args, "preq_embedding_quantize"):
+        if self.llm_config.base.preq_embedding_quantize:
             embedding_bit_width, embedding_group_size = (
-                self.args.preq_embedding_quantize.split(",")
+                self.llm_config.base.preq_embedding_quantize.split(",")
             )
             from .source_transformation.pre_quantization import (
                 transform_embedding_for_pre_quantization,
@@ -387,7 +388,7 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
             self.model_ = transform_embedding_for_pre_quantization(
                 self.model_,
                 checkpoint,
-                mapping[self.args.dtype_override],
+                mapping[self.llm_config.model.dtype_override],
                 int(embedding_bit_width),
                 embedding_group_size,
             )
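
Finally, a hypothetical wiring sketch showing the kwargs that model.py's __init__ reads after this change, with llm_config replacing the old args entry. Only the kwargs visible in this diff are used; the model class itself is not named here, so the construction call is left as a comment.

# Hypothetical usage under the mock config sketched near the top of this page.
llm_config = LlmConfig()
llm_config.base.model_class = "llama3_2"            # with use_scaled_rope, anything but llama3/llama3_1 gets rope_scale_factor = 32
llm_config.export.max_context_length = 2048
llm_config.model.use_attention_sink = "4,2044,512"  # sink_size + window_size must equal max_context_length

model_kwargs = {
    "max_seq_len": 2048,
    "max_context_len": 2048,
    "output_prune_map_path": None,
    "llm_config": llm_config,                       # replaces the old "args" kwarg
}
# e.g. SomeLlamaModel(**model_kwargs), where __init__ calls kwargs.get("llm_config", None)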