@@ -102,7 +102,7 @@ def _replace_linear_with_linear_8da4w_for_spin_quant(
 ):
     def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
         # Only replace linear layers where the checkpoint contains explicit scales
-        scales_key = f"{cur_fqn}.scale"
+        scales_key = f"{cur_fqn}.scales"
         if isinstance(child, nn.Linear) and scales_key in checkpoint:
             assert _check_linear_int4_k(child.in_features, group_size)
             assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
@@ -155,7 +155,7 @@ def _replace_output_linear_with_linear_int8_for_spinquant(
     dtype: torch.dtype,
 ):
     def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
-        scales_key = f"{cur_fqn}.scale"
+        scales_key = f"{cur_fqn}.scales"
         if (
             isinstance(child, nn.Linear)
             and scales_key in checkpoint
@@ -205,7 +205,7 @@ def _replace_embedding_with_quantized_group_embedding_for_spinquant(
 ):
     def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
         # Only replace embedding layers where the checkpoint contains explicit scales
-        scales_key = f"{cur_fqn}.scale"
+        scales_key = f"{cur_fqn}.scales"
         if isinstance(child, nn.Embedding) and scales_key in checkpoint:
             assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
             assert checkpoint[scales_key].dtype == torch.float32
@@ -250,59 +250,12 @@ def transform_embedding_for_spinquant(


 def sanitize_checkpoint_from_spinquant(
-    module: torch.nn.Module,
     checkpoint: Any,
-    linear_group_size: int,
-    embedding_group_size: Optional[int] = None,
 ):
     """
     Sanitize the SpinQuant checkpoint.
-        - Renames 'scale' to 'scales'
-        - Groups scales
-        - Removes 'o_weight'
         - Converts all tensors to contiguous format
+        - Squeezes all tensors
     """
-    keys_to_rename = []
-    keys_to_remove = []
-    for k, _ in checkpoint.items():
-        if k.endswith(".scale"):
-            new_key = k + "s"
-            keys_to_rename.append((k, new_key))
-        if k.endswith(".o_weight"):
-            keys_to_remove.append(k)
-
-    for old_key, new_key in keys_to_rename:
-        old_val = checkpoint.pop(old_key)
-        module_name = new_key[0 : new_key.rfind(".")]
-        sub_module = module.get_submodule(module_name)
-        assert sub_module is not None
-        assert (
-            isinstance(sub_module, Int8DynActInt4WeightLinear)
-            or isinstance(sub_module, QuantizedGroupEmbedding)
-            or isinstance(sub_module, Int8DynActInt8WeightLinear)
-        )
-        # Checkpoints with SpinQuant could come with two formats for scales:
-        # 1. scales is grouped by group size
-        # 2. scales is not grouped by group size
-        # We need to handle both cases here.
-        # TODO(lunwenh): remove this once we have a unified format for scales.
-        if isinstance(sub_module, Int8DynActInt4WeightLinear):
-            checkpoint[new_key] = (
-                old_val if linear_group_size == -1 else old_val[:, ::linear_group_size]
-            )
-        elif isinstance(sub_module, Int8DynActInt8WeightLinear):
-            checkpoint[new_key] = old_val[:, 0]
-        elif isinstance(sub_module, QuantizedGroupEmbedding):
-            if (
-                embedding_group_size is None or embedding_group_size == 0
-            ):  # Scales are not grouped
-                checkpoint[new_key] = old_val[:, 0]
-            elif embedding_group_size == -1:  # Scales are grouped by group size
-                checkpoint[new_key] = old_val
-            else:
-                checkpoint[new_key] = old_val[:, ::embedding_group_size]
-
-    for k in keys_to_remove:
-        checkpoint.pop(k)
     for k, v in checkpoint.items():
-        checkpoint[k] = v.contiguous()
+        checkpoint[k] = torch.squeeze(v.contiguous())
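
For illustration, a minimal, self-contained sketch of the simplified sanitizer in use. Only the function body is taken verbatim from this diff; the state-dict keys, shapes, and values below are made-up stand-ins, not contents of a real SpinQuant checkpoint.

from typing import Any

import torch


def sanitize_checkpoint_from_spinquant(checkpoint: Any):
    """Make every tensor contiguous and squeeze away singleton dimensions (per this diff)."""
    for k, v in checkpoint.items():
        checkpoint[k] = torch.squeeze(v.contiguous())


# Hypothetical stand-in for a SpinQuant state dict: int8 weights plus float32
# per-group scales stored under f"{fqn}.scales", the key the filter functions
# above now look up directly.
checkpoint = {
    "layers.0.attention.wq.weight": torch.zeros(8, 8, dtype=torch.int8),
    "layers.0.attention.wq.scales": torch.ones(8, 1, dtype=torch.float32),
}

sanitize_checkpoint_from_spinquant(checkpoint)

# The trailing singleton dimension on the scales tensor has been squeezed away.
assert checkpoint["layers.0.attention.wq.scales"].shape == (8,)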