@@ -165,7 +165,7 @@ def __init__(self, **kwargs):
             )
         elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
             print("Using SPIN quantization.")
-            self._transform_for_pre_quantization(checkpoint)
+            self._transform_for_pre_quantization(checkpoint, model_args)

             from .source_transformation.pre_quantization import (
                 sanitize_checkpoint_from_pre_quantization,
@@ -174,8 +174,9 @@ def __init__(self, **kwargs):
             sanitize_checkpoint_from_pre_quantization(checkpoint)
         elif hasattr(self.args, "use_qat") and self.args.use_qat:
             print("Using QAT quantization.")
-            self._transform_for_pre_quantization(checkpoint)
+            self._transform_for_pre_quantization(checkpoint, model_args)
             if hasattr(self.args, "use_lora") and self.args.use_lora:
+                assert model_args.lora_args["rank"] == self.args.use_lora
                 from .source_transformation.lora import (
                     transform_linear_for_lora_after_quantization,
                 )
@@ -251,7 +252,7 @@ def get_example_inputs_kvcache_sdpa(self):
             ),  # start_pos, what token of output are we on.
         )

-    def _transform_for_pre_quantization(self, checkpoint):
+    def _transform_for_pre_quantization(self, checkpoint, model_args):
         assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
         assert self.args.preq_mode in [
             "8da4w",
@@ -265,6 +266,8 @@ def _transform_for_pre_quantization(self, checkpoint):
             transform_linear_for_pre_quantization,
         )

+        assert self.args.preq_group_size == model_args.quantization_args["group_size"]
+
         mapping = {
             "fp32": torch.float32,
             "fp16": torch.float16,