@@ -55,7 +55,7 @@ def __init__(self, **kwargs):
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
         self.max_seq_len = kwargs.get("max_seq_len", 128)
         self.max_context_len = kwargs.get("max_context_len", 128)
-        self.args = kwargs.get("args", None)
+        self.llm_config = kwargs.get("llm_config", None)
 
         assert (
             self.max_context_len >= self.max_seq_len
@@ -158,10 +158,11 @@ def __init__(self, **kwargs):
 
         if model_args.use_scaled_rope:
             # Older models don't have use_scaled_rope configuration
-            assert self.args.model not in ["llama2", "stories110m"]
+            model_name = str(self.llm_config.base.model_class) if self.llm_config else "llama3"
+            assert model_name not in ["llama2", "stories110m"]
 
             # Llama3_2 and newer models in ExecuTorch repo should set larger scale factor
-            if self.args.model not in ["llama3", "llama3_1"]:
+            if model_name not in ["llama3", "llama3_1"]:
                 model_args.rope_scale_factor = 32
 
         if kwargs.get("verbose", False):
@@ -196,7 +197,7 @@ def __init__(self, **kwargs):
             self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime(
                 self.model_
             )
-        elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
+        elif self.llm_config and self.llm_config.quantization.use_spin_quant:
             print("Using SPIN quantization.")
             self._transform_for_pre_quantization(checkpoint, model_args)
 
@@ -205,19 +206,20 @@ def __init__(self, **kwargs):
             )
 
             sanitize_checkpoint_from_pre_quantization(checkpoint)
-        elif hasattr(self.args, "use_qat") and self.args.use_qat:
+        elif self.llm_config and self.llm_config.quantization.use_qat:
             print("Using QAT quantization.")
             self._transform_for_pre_quantization(checkpoint, model_args)
-            if hasattr(self.args, "use_lora") and self.args.use_lora:
-                assert model_args.lora_args["rank"] == self.args.use_lora
+            if self.llm_config.base.use_lora:
+                lora_rank = self.llm_config.base.use_lora
+                assert model_args.lora_args["rank"] == lora_rank
                 from .source_transformation.lora import (
                     transform_linear_for_lora_after_quantization,
                 )
 
                 self.model_ = transform_linear_for_lora_after_quantization(
                     self.model_,
                     checkpoint,
-                    self.args.use_lora,
+                    lora_rank,
                 )
 
             from .source_transformation.pre_quantization import (
@@ -226,16 +228,16 @@ def __init__(self, **kwargs):
 
             sanitize_checkpoint_from_pre_quantization(checkpoint)
 
-        if hasattr(self.args, "use_attention_sink") and self.args.use_attention_sink:
+        if self.llm_config and self.llm_config.model.use_attention_sink:
             from .source_transformation.attention_sink import enable_attention_sink
 
-            attention_sink_params = self.args.use_attention_sink.split(",")
+            attention_sink_params = self.llm_config.model.use_attention_sink.split(",")
             assert len(attention_sink_params) == 3
             sink_size = int(attention_sink_params[0])
             window_size = int(attention_sink_params[1])
             eviction_batch_size = int(attention_sink_params[2])
 
-            assert self.args.max_context_length == sink_size + window_size
+            assert self.llm_config.export.max_context_length == sink_size + window_size
 
             self.model_ = enable_attention_sink(
                 module=self.model_,
@@ -326,20 +328,19 @@ def get_example_inputs_kvcache_sdpa(self):
             )
 
     def _transform_for_pre_quantization(self, checkpoint, model_args):
-        assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
-        assert self.args.preq_mode in [
+        assert self.llm_config and self.llm_config.base.preq_mode, "preq_mode must be specified"
+        assert self.llm_config.base.preq_mode in [
             "8da4w",
             "8da4w_output_8da8w",
-        ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant."
-        assert hasattr(
-            self.args, "preq_group_size"
-        ), "preq_group_size must be specified"
-        assert hasattr(self.args, "dtype_override"), "dtype_override must be specified"
+        ], f"Quantization mode {self.llm_config.base.preq_mode} is not compatible with SpinQuant."
+        assert self.llm_config.base.preq_group_size, "preq_group_size must be specified"
+        assert self.llm_config.model.dtype_override, "dtype_override must be specified"
+
         from .source_transformation.pre_quantization import (
             transform_linear_for_pre_quantization,
         )
 
-        assert self.args.preq_group_size == model_args.quantization_args["group_size"]
+        assert self.llm_config.base.preq_group_size == model_args.quantization_args["group_size"]
 
         mapping = {
             "fp32": torch.float32,
@@ -348,28 +349,28 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
         }
 
         # Transform the output layer first if needed.
-        if self.args.preq_mode == "8da4w_output_8da8w":
+        if self.llm_config.base.preq_mode == "8da4w_output_8da8w":
            from .source_transformation.pre_quantization import (
                transform_output_linear_for_pre_quantization,
            )
 
            self.model_ = transform_output_linear_for_pre_quantization(
                module=self.model_,
                checkpoint=checkpoint,
-                dtype=mapping[self.args.dtype_override],
+                dtype=mapping[self.llm_config.model.dtype_override],
            )
 
        self.model_ = transform_linear_for_pre_quantization(
            self.model_,
            checkpoint,
-            self.args.preq_group_size,
-            mapping[self.args.dtype_override],
+            self.llm_config.base.preq_group_size,
+            mapping[self.llm_config.model.dtype_override],
        )
 
        embedding_bit_width, embedding_group_size = None, None
-        if hasattr(self.args, "preq_embedding_quantize"):
+        if self.llm_config.base.preq_embedding_quantize:
            embedding_bit_width, embedding_group_size = (
-                self.args.preq_embedding_quantize.split(",")
+                self.llm_config.base.preq_embedding_quantize.split(",")
            )
            from .source_transformation.pre_quantization import (
                transform_embedding_for_pre_quantization,
@@ -387,7 +388,7 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
         self.model_ = transform_embedding_for_pre_quantization(
             self.model_,
             checkpoint,
-            mapping[self.args.dtype_override],
+            mapping[self.llm_config.model.dtype_override],
             int(embedding_bit_width),
             embedding_group_size,
         )