refactor: remove backward compatibility and simplify Llama2Model configuration #11169

Closed · wants to merge 1 commit
5 changes: 0 additions & 5 deletions examples/models/llama/export_llama_lib.py
@@ -1226,11 +1226,6 @@ def _load_llama_model(
         EagerModelFactory.create_model(
             module_name,
             model_class_name,
-            checkpoint=checkpoint,
-            checkpoint_dir=checkpoint_dir,
-            params=params_path,
-            fairseq2=weight_type == WeightType.FAIRSEQ2,
-            dtype=torch_dtype,
             llm_config=llm_config,
         )
     )
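
Note on the change above: `EagerModelFactory.create_model` is now driven entirely by `llm_config` rather than the removed `checkpoint` / `checkpoint_dir` / `params` / `fairseq2` / `dtype` keyword arguments. A minimal sketch of the resulting calling convention follows; the `create_model` stub and the config fields shown are illustrative assumptions, not the actual ExecuTorch `EagerModelFactory` or `LlmConfig` definitions.

```python
# Illustrative sketch only: the factory stub and config fields below are
# assumptions, not the real ExecuTorch EagerModelFactory / LlmConfig API.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class BaseSection:
    checkpoint: Optional[str] = None  # assumed field name
    params: Optional[str] = None      # assumed field name


@dataclass
class LlmConfigSketch:
    base: BaseSection = field(default_factory=BaseSection)


def create_model(module_name: str, model_class_name: str, *, llm_config: LlmConfigSketch):
    # A real factory would import `module_name` and instantiate `model_class_name`;
    # the point here is that the config object is the only configuration argument.
    print(f"Building {model_class_name} from {module_name} with {llm_config}")


create_model(
    "examples.models.llama",
    "Llama2Model",
    llm_config=LlmConfigSketch(base=BaseSection(checkpoint="ckpt.pth", params="params.json")),
)
```

The design point is that every setting the model needs travels through one object, so adding a new option no longer means threading another keyword argument through the factory.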
44 changes: 16 additions & 28 deletions examples/models/llama/model.py
@@ -47,30 +47,18 @@ def __init__(self, **kwargs):
         # Params file.
         params_path = kwargs.get("params", None)

-        self.llm_config = kwargs.get("llm_config", None)
+        self.llm_config = kwargs.get("llm_config")
+        assert self.llm_config is not None, "llm_config must be provided"

-        # Set all parameters from llm_config if available, otherwise use kwargs as fallback
-        if self.llm_config:
-            self.use_kv_cache = self.llm_config.model.use_kv_cache
-            self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
-            self.generate_full_logits = self.llm_config.debug.generate_full_logits
-            self.enable_dynamic_shape = self.llm_config.model.enable_dynamic_shape
-            self.input_prune_map_path = self.llm_config.model.input_prune_map
-            self.output_prune_map_path = self.llm_config.model.output_prune_map
-            self.max_seq_len = self.llm_config.export.max_seq_length
-            self.max_context_len = self.llm_config.export.max_context_length
-            self.verbose = self.llm_config.debug.verbose
-        else:
-            # Fallback to kwargs for backward compatibility
-            self.use_kv_cache = kwargs.get("use_kv_cache", False)
-            self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
-            self.generate_full_logits = kwargs.get("generate_full_logits", False)
-            self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
-            self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
-            self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
-            self.max_seq_len = kwargs.get("max_seq_len", 128)
-            self.max_context_len = kwargs.get("max_context_len", 128)
-            self.verbose = kwargs.get("verbose", False)
+        self.use_kv_cache = self.llm_config.model.use_kv_cache
+        self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
+        self.generate_full_logits = self.llm_config.debug.generate_full_logits
+        self.enable_dynamic_shape = self.llm_config.model.enable_dynamic_shape
+        self.input_prune_map_path = self.llm_config.model.input_prune_map
+        self.output_prune_map_path = self.llm_config.model.output_prune_map
+        self.max_seq_len = self.llm_config.export.max_seq_length
+        self.max_context_len = self.llm_config.export.max_context_length
+        self.verbose = self.llm_config.debug.verbose

         assert (
             self.max_context_len >= self.max_seq_len
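
Note on the change above: the constructor stops branching on whether `llm_config` was supplied; it asserts that the config is present and reads every setting directly from it. A self-contained sketch of that pattern, assuming stand-in dataclasses in place of the real `LlmConfig` (only the field names that appear in the diff are taken from the source):

```python
# Self-contained sketch of the simplified pattern; the dataclasses are
# illustrative stand-ins, not the actual ExecuTorch LlmConfig definitions.
from dataclasses import dataclass, field


@dataclass
class ModelSection:
    use_kv_cache: bool = False
    use_sdpa_with_kv_cache: bool = False
    enable_dynamic_shape: bool = False


@dataclass
class ExportSection:
    max_seq_length: int = 128
    max_context_length: int = 128


@dataclass
class LlmConfigSketch:
    model: ModelSection = field(default_factory=ModelSection)
    export: ExportSection = field(default_factory=ExportSection)


class Llama2ModelSketch:
    def __init__(self, **kwargs):
        # The kwargs fallback path is gone: llm_config is now mandatory.
        self.llm_config = kwargs.get("llm_config")
        assert self.llm_config is not None, "llm_config must be provided"

        # Settings are read directly from the config, with no per-kwarg defaults.
        self.use_kv_cache = self.llm_config.model.use_kv_cache
        self.max_seq_len = self.llm_config.export.max_seq_length
        self.max_context_len = self.llm_config.export.max_context_length
        assert self.max_context_len >= self.max_seq_len


model = Llama2ModelSketch(llm_config=LlmConfigSketch())
```

Callers that previously relied on the kwargs fallback (e.g. `use_kv_cache=True`, `max_seq_len=...`) must now populate the config object instead.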
@@ -173,7 +161,7 @@ def __init__(self, **kwargs):

         if model_args.use_scaled_rope:
             # Older models don't have use_scaled_rope configuration
-            model_name = str(self.llm_config.base.model_class) if self.llm_config else "llama3"
+            model_name = str(self.llm_config.base.model_class)
             assert model_name not in ["llama2", "stories110m"]

             # Llama3_2 and newer models in ExecuTorch repo should set larger scale factor
@@ -212,7 +200,7 @@ def __init__(self, **kwargs):
                 self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime(
                     self.model_
                 )
-            elif self.llm_config and self.llm_config.quantization.use_spin_quant:
+            elif self.llm_config.quantization.use_spin_quant:
                 print("Using SPIN quantization.")
                 self._transform_for_pre_quantization(checkpoint, model_args)

@@ -221,7 +209,7 @@ def __init__(self, **kwargs):
                 )

                 sanitize_checkpoint_from_pre_quantization(checkpoint)
-            elif self.llm_config and self.llm_config.quantization.use_qat:
+            elif self.llm_config.quantization.use_qat:
                 print("Using QAT quantization.")
                 self._transform_for_pre_quantization(checkpoint, model_args)
                 if self.llm_config.base.use_lora:
@@ -243,7 +231,7 @@ def __init__(self, **kwargs):

                 sanitize_checkpoint_from_pre_quantization(checkpoint)

-        if self.llm_config and self.llm_config.model.use_attention_sink:
+        if self.llm_config.model.use_attention_sink:
             from .source_transformation.attention_sink import enable_attention_sink

             attention_sink_params = self.llm_config.model.use_attention_sink.split(",")
@@ -343,7 +331,7 @@ def get_example_inputs_kvcache_sdpa(self):
        )

    def _transform_for_pre_quantization(self, checkpoint, model_args):
-        assert self.llm_config and self.llm_config.base.preq_mode, "preq_mode must be specified"
+        assert self.llm_config.base.preq_mode, "preq_mode must be specified"
        assert self.llm_config.base.preq_mode in [
            "8da4w",
            "8da4w_output_8da8w",