
Commit 8d55ee4

refactor: Use LlmConfig for model parameters instead of kwargs
ghstack-source-id: 7d2d0d0
Pull Request resolved: #11168
1 parent 11b2f0f commit 8d55ee4

File tree: 2 files changed (+25, -18 lines)


examples/models/llama/export_llama_lib.py

Lines changed: 0 additions & 8 deletions

@@ -1229,15 +1229,7 @@ def _load_llama_model(
             checkpoint=checkpoint,
             checkpoint_dir=checkpoint_dir,
             params=params_path,
-            use_kv_cache=use_kv_cache,
-            use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
-            generate_full_logits=generate_full_logits,
             fairseq2=weight_type == WeightType.FAIRSEQ2,
-            max_seq_len=max_seq_len,
-            max_context_len=max_context_len,
-            enable_dynamic_shape=enable_dynamic_shape,
-            input_prune_map_path=input_prune_map_path,
-            output_prune_map_path=output_prune_map_path,
             dtype=torch_dtype,
             llm_config=llm_config,
         )
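
After this change the call site forwards a single llm_config object and drops the eight per-parameter keyword arguments. A minimal sketch of that pattern follows; types.SimpleNamespace stands in for the real LlmConfig, and only the attribute paths read in model.py (shown in the next diff) are assumed — the actual dataclass constructors and the placeholder file paths below are hypothetical.

from types import SimpleNamespace

# Stand-in for the real LlmConfig: only the sub-configs and fields that
# model.py reads (llm_config.model.*, llm_config.export.*, llm_config.debug.*)
# are modeled here; everything else about the real class is assumed.
llm_config = SimpleNamespace(
    model=SimpleNamespace(
        use_kv_cache=True,
        use_sdpa_with_kv_cache=True,
        enable_dynamic_shape=True,
        input_prune_map=None,
        output_prune_map=None,
    ),
    export=SimpleNamespace(max_seq_length=128, max_context_length=128),
    debug=SimpleNamespace(generate_full_logits=False, verbose=False),
)

# Call-site pattern after the refactor: the whole config travels as one
# keyword argument instead of eight separate ones.
model_kwargs = {
    "checkpoint": "path/to/consolidated.00.pth",  # hypothetical path
    "params": "path/to/params.json",              # hypothetical path
    "llm_config": llm_config,
}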

examples/models/llama/model.py

Lines changed: 25 additions & 10 deletions

@@ -47,15 +47,30 @@ def __init__(self, **kwargs):
         # Params file.
         params_path = kwargs.get("params", None)

-        self.use_kv_cache = kwargs.get("use_kv_cache", False)
-        self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
-        self.generate_full_logits = kwargs.get("generate_full_logits", False)
-        self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
-        self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
-        self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
-        self.max_seq_len = kwargs.get("max_seq_len", 128)
-        self.max_context_len = kwargs.get("max_context_len", 128)
         self.llm_config = kwargs.get("llm_config", None)
+
+        # Set all parameters from llm_config if available, otherwise use kwargs as fallback
+        if self.llm_config:
+            self.use_kv_cache = self.llm_config.model.use_kv_cache
+            self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
+            self.generate_full_logits = self.llm_config.debug.generate_full_logits
+            self.enable_dynamic_shape = self.llm_config.model.enable_dynamic_shape
+            self.input_prune_map_path = self.llm_config.model.input_prune_map
+            self.output_prune_map_path = self.llm_config.model.output_prune_map
+            self.max_seq_len = self.llm_config.export.max_seq_length
+            self.max_context_len = self.llm_config.export.max_context_length
+            self.verbose = self.llm_config.debug.verbose
+        else:
+            # Fallback to kwargs for backward compatibility
+            self.use_kv_cache = kwargs.get("use_kv_cache", False)
+            self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
+            self.generate_full_logits = kwargs.get("generate_full_logits", False)
+            self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
+            self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
+            self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
+            self.max_seq_len = kwargs.get("max_seq_len", 128)
+            self.max_context_len = kwargs.get("max_context_len", 128)
+            self.verbose = kwargs.get("verbose", False)

         assert (
             self.max_context_len >= self.max_seq_len

@@ -165,7 +180,7 @@ def __init__(self, **kwargs):
         if model_name not in ["llama3", "llama3_1"]:
             model_args.rope_scale_factor = 32

-        if kwargs.get("verbose", False):
+        if self.verbose:
             print("============= weights ================")
             print("{key} : {weights.numel()} : {weights.size()}")
             for key, weights in checkpoint.items():

@@ -280,7 +295,7 @@ def __init__(self, **kwargs):
                 f"The provided checkpoint is missing the following weights that are expected by the model: {missing_weights}. Please fix the fqn's in your checkpoint to match."
             )
         if unexpected:
-            if kwargs.get("verbose", False):
+            if self.verbose:
                 print(f"Unexpected keys: {unexpected}")

         # Prune the input layer if input_prune_map is provided
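
One behavioral detail worth noting: the fallback is all-or-nothing, not field-by-field. If llm_config is supplied, the individual kwargs (use_kv_cache, max_seq_len, and so on) are ignored outright; the kwargs path with the old defaults applies only when no config is passed at all. A small self-contained sketch of that precedence follows — resolve_max_lengths is a hypothetical helper written for illustration, not code from the repository.

from types import SimpleNamespace

def resolve_max_lengths(llm_config=None, **kwargs):
    # Mirrors the precedence added to __init__ above: the config wins
    # outright when present; otherwise kwargs and the old defaults apply.
    if llm_config:
        return llm_config.export.max_seq_length, llm_config.export.max_context_length
    return kwargs.get("max_seq_len", 128), kwargs.get("max_context_len", 128)

cfg = SimpleNamespace(export=SimpleNamespace(max_seq_length=256, max_context_length=512))

assert resolve_max_lengths(cfg, max_seq_len=64) == (256, 512)  # kwarg ignored when a config is given
assert resolve_max_lengths(max_seq_len=64) == (64, 128)        # kwargs path with old defaults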
