@@ -47,30 +47,18 @@ def __init__(self, **kwargs):
         # Params file.
         params_path = kwargs.get("params", None)
 
-        self.llm_config = kwargs.get("llm_config", None)
+        self.llm_config = kwargs.get("llm_config")
+        assert self.llm_config is not None, "llm_config must be provided"
 
-        # Set all parameters from llm_config if available, otherwise use kwargs as fallback
-        if self.llm_config:
-            self.use_kv_cache = self.llm_config.model.use_kv_cache
-            self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
-            self.generate_full_logits = self.llm_config.debug.generate_full_logits
-            self.enable_dynamic_shape = self.llm_config.model.enable_dynamic_shape
-            self.input_prune_map_path = self.llm_config.model.input_prune_map
-            self.output_prune_map_path = self.llm_config.model.output_prune_map
-            self.max_seq_len = self.llm_config.export.max_seq_length
-            self.max_context_len = self.llm_config.export.max_context_length
-            self.verbose = self.llm_config.debug.verbose
-        else:
-            # Fallback to kwargs for backward compatibility
-            self.use_kv_cache = kwargs.get("use_kv_cache", False)
-            self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
-            self.generate_full_logits = kwargs.get("generate_full_logits", False)
-            self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
-            self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
-            self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
-            self.max_seq_len = kwargs.get("max_seq_len", 128)
-            self.max_context_len = kwargs.get("max_context_len", 128)
-            self.verbose = kwargs.get("verbose", False)
+        self.use_kv_cache = self.llm_config.model.use_kv_cache
+        self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
+        self.generate_full_logits = self.llm_config.debug.generate_full_logits
+        self.enable_dynamic_shape = self.llm_config.model.enable_dynamic_shape
+        self.input_prune_map_path = self.llm_config.model.input_prune_map
+        self.output_prune_map_path = self.llm_config.model.output_prune_map
+        self.max_seq_len = self.llm_config.export.max_seq_length
+        self.max_context_len = self.llm_config.export.max_context_length
+        self.verbose = self.llm_config.debug.verbose
 
         assert (
             self.max_context_len >= self.max_seq_len
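
With the kwargs fallback gone, a caller can no longer pass use_kv_cache, max_seq_len, and friends as individual keyword arguments; it has to hand the constructor a populated llm_config. A minimal caller-side sketch, assuming the LlmConfig class, its import path, and the Llama2Model entry point (only the config field names and the "params"/"llm_config" kwargs are taken from the diff itself):

    # Sketch only: import paths and class names are assumptions, not part of this diff.
    from executorch.extension.llm.export.config.llm_config import LlmConfig
    from executorch.examples.models.llama.model import Llama2Model

    llm_config = LlmConfig()
    llm_config.model.use_kv_cache = True            # replaces kwargs.get("use_kv_cache", False)
    llm_config.model.use_sdpa_with_kv_cache = True  # replaces kwargs.get("use_sdpa_with_kv_cache", False)
    llm_config.export.max_seq_length = 128          # replaces kwargs.get("max_seq_len", 128)
    llm_config.export.max_context_length = 128      # replaces kwargs.get("max_context_len", 128)

    # The constructor now asserts that llm_config is provided.
    model = Llama2Model(llm_config=llm_config, params="params.json")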
@@ -173,7 +161,7 @@ def __init__(self, **kwargs):
 
         if model_args.use_scaled_rope:
             # Older models don't have use_scaled_rope configuration
-            model_name = str(self.llm_config.base.model_class) if self.llm_config else "llama3"
+            model_name = str(self.llm_config.base.model_class)
             assert model_name not in ["llama2", "stories110m"]
 
             # Llama3_2 and newer models in ExecuTorch repo should set larger scale factor
@@ -212,7 +200,7 @@ def __init__(self, **kwargs):
             self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime(
                 self.model_
             )
-        elif self.llm_config and self.llm_config.quantization.use_spin_quant:
+        elif self.llm_config.quantization.use_spin_quant:
             print("Using SPIN quantization.")
             self._transform_for_pre_quantization(checkpoint, model_args)
 
@@ -221,7 +209,7 @@ def __init__(self, **kwargs):
             )
 
             sanitize_checkpoint_from_pre_quantization(checkpoint)
-        elif self.llm_config and self.llm_config.quantization.use_qat:
+        elif self.llm_config.quantization.use_qat:
             print("Using QAT quantization.")
             self._transform_for_pre_quantization(checkpoint, model_args)
             if self.llm_config.base.use_lora:
@@ -243,7 +231,7 @@ def __init__(self, **kwargs):
 
             sanitize_checkpoint_from_pre_quantization(checkpoint)
 
-        if self.llm_config and self.llm_config.model.use_attention_sink:
+        if self.llm_config.model.use_attention_sink:
            from .source_transformation.attention_sink import enable_attention_sink
 
            attention_sink_params = self.llm_config.model.use_attention_sink.split(",")
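
For the attention-sink branch above, use_attention_sink is read from the config as a comma-separated string and split into attention_sink_params. A hypothetical setting, with the concrete value and its meaning assumed rather than taken from this diff:

    # Hypothetical value; the diff only shows that the string is split on ",".
    llm_config.model.use_attention_sink = "4,2044,1024"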
@@ -343,7 +331,7 @@ def get_example_inputs_kvcache_sdpa(self):
         )
 
     def _transform_for_pre_quantization(self, checkpoint, model_args):
-        assert self.llm_config and self.llm_config.base.preq_mode, "preq_mode must be specified"
+        assert self.llm_config.base.preq_mode, "preq_mode must be specified"
         assert self.llm_config.base.preq_mode in [
             "8da4w",
             "8da4w_output_8da8w",