
Commit 52d5218

Lunwen He authored and facebook-github-bot committed
Update SpinQuant checkpoint export to align with the new format (#5645)

Summary: Pull Request resolved: #5645

Per the new aligned checkpoint format for SpinQuant:
- Original weights: dropped from the checkpoint
- Int4 or int8 weights: stored under .weight
- Scales: stored under .scales, grouped by group size for the attention linear layers and per channel for the embedding and output layers

This PR updates the export flow to follow the new format.

Reviewed By: mergennachin
Differential Revision: D63402708
fbshipit-source-id: e06a45b95ad7628732cfadb803a16f650042cd97
1 parent 7ab977e commit 52d5218
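
For reference, a minimal sketch of the checkpoint layout the export flow now emits. All key names, shapes, and the group size below are illustrative assumptions, not values taken from a real Llama checkpoint:

import torch

# Hypothetical new-format SpinQuant checkpoint (names and shapes assumed).
group_size = 256
checkpoint = {
    # Attention linear layers: 4-bit values stored in int8, plus grouped scales,
    # one scale per group_size input columns, i.e. 4096 // 256 = 16 per row.
    "layers.0.attention.wq.weight": torch.zeros(4096, 4096, dtype=torch.int8),
    "layers.0.attention.wq.scales": torch.zeros(4096, 4096 // group_size),
    # Embedding and output layers: int8 weights plus per-channel scales.
    "tok_embeddings.weight": torch.zeros(32000, 4096, dtype=torch.int8),
    "tok_embeddings.scales": torch.zeros(32000),
    "output.weight": torch.zeros(32000, 4096, dtype=torch.int8),
    "output.scales": torch.zeros(32000),
}
# Note: the original floating-point weights are no longer present in the checkpoint.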

File tree

3 files changed (+12, -77 lines)

examples/models/llama2/model.py

Lines changed: 1 addition & 6 deletions

@@ -258,12 +258,7 @@ def __init__(self, **kwargs):
                 embedding_group_size,
             )

-            sanitize_checkpoint_from_spinquant(
-                module=self.model_,
-                checkpoint=checkpoint,
-                linear_group_size=self.args.spin_group_size,
-                embedding_group_size=embedding_group_size,
-            )
+            sanitize_checkpoint_from_spinquant(checkpoint)

         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
         # Because we are using device="meta", tensors do not have memory associated with them
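
The surrounding context loads the sanitized checkpoint into a model whose parameters live on the meta device, via load_state_dict(..., assign=True). A self-contained sketch of that PyTorch pattern (the toy module is made up; the assign keyword is available in recent PyTorch releases):

import torch

# Build a module on the meta device: parameters have shapes but no storage.
with torch.device("meta"):
    model = torch.nn.Linear(8, 8)

state_dict = {"weight": torch.randn(8, 8), "bias": torch.randn(8)}

# assign=True swaps the meta tensors for the checkpoint tensors by assignment,
# instead of copying into storage that the meta tensors do not have.
model.load_state_dict(state_dict, assign=True)
print(model.weight.device)  # cpu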

examples/models/llama2/source_transformation/spin_quant.py

Lines changed: 5 additions & 52 deletions

@@ -102,7 +102,7 @@ def _replace_linear_with_linear_8da4w_for_spin_quant(
 ):
     def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
         # Only replace linear layers where the checkpoint contains explicit scales
-        scales_key = f"{cur_fqn}.scale"
+        scales_key = f"{cur_fqn}.scales"
         if isinstance(child, nn.Linear) and scales_key in checkpoint:
             assert _check_linear_int4_k(child.in_features, group_size)
             assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8

@@ -155,7 +155,7 @@ def _replace_output_linear_with_linear_int8_for_spinquant(
     dtype: torch.dtype,
 ):
     def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
-        scales_key = f"{cur_fqn}.scale"
+        scales_key = f"{cur_fqn}.scales"
         if (
             isinstance(child, nn.Linear)
             and scales_key in checkpoint

@@ -205,7 +205,7 @@ def _replace_embedding_with_quantized_group_embedding_for_spinquant(
 ):
     def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
         # Only replace embedding layers where the checkpoint contains explicit scales
-        scales_key = f"{cur_fqn}.scale"
+        scales_key = f"{cur_fqn}.scales"
         if isinstance(child, nn.Embedding) and scales_key in checkpoint:
             assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
             assert checkpoint[scales_key].dtype == torch.float32

@@ -250,59 +250,12 @@ def transform_embedding_for_spinquant(


 def sanitize_checkpoint_from_spinquant(
-    module: torch.nn.Module,
     checkpoint: Any,
-    linear_group_size: int,
-    embedding_group_size: Optional[int] = None,
 ):
     """
     Sanitize the SpinQuant checkpoint.
-    - Renames 'scale' to 'scales'
-    - Groups scales
-    - Removes 'o_weight'
     - Converts all tensors to contiguous format
+    - Squeeze all tensors
     """
-    keys_to_rename = []
-    keys_to_remove = []
-    for k, _ in checkpoint.items():
-        if k.endswith(".scale"):
-            new_key = k + "s"
-            keys_to_rename.append((k, new_key))
-        if k.endswith(".o_weight"):
-            keys_to_remove.append(k)
-
-    for old_key, new_key in keys_to_rename:
-        old_val = checkpoint.pop(old_key)
-        module_name = new_key[0 : new_key.rfind(".")]
-        sub_module = module.get_submodule(module_name)
-        assert sub_module is not None
-        assert (
-            isinstance(sub_module, Int8DynActInt4WeightLinear)
-            or isinstance(sub_module, QuantizedGroupEmbedding)
-            or isinstance(sub_module, Int8DynActInt8WeightLinear)
-        )
-        # Checkpoints with SpinQuant could come with two formats for scales:
-        # 1. scales is grouped by group size
-        # 2. scales is not grouped by group size
-        # We need to handle both cases here.
-        # TODO(lunwenh): remove this once we have a unified format for scales.
-        if isinstance(sub_module, Int8DynActInt4WeightLinear):
-            checkpoint[new_key] = (
-                old_val if linear_group_size == -1 else old_val[:, ::linear_group_size]
-            )
-        elif isinstance(sub_module, Int8DynActInt8WeightLinear):
-            checkpoint[new_key] = old_val[:, 0]
-        elif isinstance(sub_module, QuantizedGroupEmbedding):
-            if (
-                embedding_group_size is None or embedding_group_size == 0
-            ):  # Scales are not grouped
-                checkpoint[new_key] = old_val[:, 0]
-            elif embedding_group_size == -1:  # Scales are grouped by group size
-                checkpoint[new_key] = old_val
-            else:
-                checkpoint[new_key] = old_val[:, ::embedding_group_size]
-
-    for k in keys_to_remove:
-        checkpoint.pop(k)
     for k, v in checkpoint.items():
-        checkpoint[k] = v.contiguous()
+        checkpoint[k] = torch.squeeze(v.contiguous())
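
With the rename, re-grouping, and o_weight removal now handled at export time, the sanitizer reduces to a pure layout pass. A toy demonstration of its effect; the key names and shapes here are assumed for illustration:

import torch

def sanitize(checkpoint):
    # Same logic as the new sanitize_checkpoint_from_spinquant: make every
    # tensor contiguous and drop singleton dimensions.
    for k, v in checkpoint.items():
        checkpoint[k] = torch.squeeze(v.contiguous())

ckpt = {
    "output.scales": torch.ones(32000, 1),       # per-channel scales with a stray 1-dim
    "layers.0.wq.scales": torch.ones(4096, 16),  # grouped scales, already 2-D
}
sanitize(ckpt)
print(ckpt["output.scales"].shape)       # torch.Size([32000])
print(ckpt["layers.0.wq.scales"].shape)  # torch.Size([4096, 16]) -- unchanged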

examples/models/llama2/tests/test_spinquant_transforms.py

Lines changed: 6 additions & 19 deletions

@@ -65,7 +65,7 @@ def test_transform_linear_for_spinquant(self):
             weight.to(torch.float32), n_bit, group_size, scales_precision
         )
         checkpoint[f"{fqn}.weight"] = weight_int8.to("cpu")
-        checkpoint[f"{fqn}.scale"] = scales.to("cpu")
+        checkpoint[f"{fqn}.scales"] = scales.to("cpu")

         # Step 3:
         # Transform the model so that it is compatible with the new checkpoint

@@ -76,11 +76,7 @@ def test_transform_linear_for_spinquant(self):
             "8da4w",
             torch.float32,
         )
-        sanitize_checkpoint_from_spinquant(
-            model,
-            checkpoint,
-            -1,
-        )
+        sanitize_checkpoint_from_spinquant(checkpoint)

         model.load_state_dict(
             checkpoint,

@@ -114,7 +110,7 @@ def test_transform_output_linear_for_spinquant(self):
             scales_dtype=torch.float32,
         )
         checkpoint[f"{fqn}.weight"] = weight_int8.to("cpu")
-        checkpoint[f"{fqn}.scale"] = scales.to("cpu")
+        checkpoint[f"{fqn}.scales"] = scales.to("cpu")

         # Step 3:
         # Transform the model so that it is compatible with the new checkpoint

@@ -123,11 +119,7 @@ def test_transform_output_linear_for_spinquant(self):
             checkpoint,
             torch.float32,
         )
-        sanitize_checkpoint_from_spinquant(
-            model,
-            checkpoint,
-            -1,
-        )
+        sanitize_checkpoint_from_spinquant(checkpoint)

         model.load_state_dict(
             checkpoint,

@@ -166,7 +158,7 @@ def test_transform_embedding_for_spinquant(self):
             weight.to(torch.float32), n_bit, group_size, scales_precision
         )
         checkpoint[f"{fqn}.weight"] = weight_int8.to("cpu")
-        checkpoint[f"{fqn}.scale"] = scales.to("cpu")
+        checkpoint[f"{fqn}.scales"] = scales.to("cpu")

         # Step 3:
         # Transform the model so that it is compatible with the new checkpoint

@@ -177,12 +169,7 @@ def test_transform_embedding_for_spinquant(self):
             n_bit,
             group_size,
         )
-        sanitize_checkpoint_from_spinquant(
-            module=model,
-            checkpoint=checkpoint,
-            linear_group_size=-1,
-            embedding_group_size=-1,
-        )
+        sanitize_checkpoint_from_spinquant(checkpoint)

         model.load_state_dict(
             checkpoint,
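
All three tests now exercise the same convention: a layer is swapped for its quantized counterpart only when the checkpoint carries a matching .scales entry. A generic sketch of that selection predicate; the traversal and placeholder module below are illustrative, not the torchao replacement helper the real code uses:

import torch
import torch.nn as nn

def select_layers_to_quantize(model: nn.Module, checkpoint: dict) -> list:
    # Mirror the filter_fn predicate: a child qualifies only when the
    # checkpoint contains explicit scales under "<fqn>.scales".
    return [
        fqn
        for fqn, child in model.named_modules()
        if isinstance(child, nn.Linear) and f"{fqn}.scales" in checkpoint
    ]

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))
checkpoint = {"0.scales": torch.ones(4)}  # only the first linear has scales
print(select_layers_to_quantize(model, checkpoint))  # ['0']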
