
Commit 3e79ea4

Lunwen He authored and facebook-github-bot committed
Transform embedding from SpinQuant checkpoint (#5552)
Summary:
Pull Request resolved: #5552

This diff updates the llama export code so it can load a SpinQuant checkpoint in which all linear layers and the embedding table are quantized.

Reviewed By: mergennachin

Differential Revision: D62665632

fbshipit-source-id: 3e8cda37ac16b65543e3123ea59526352ac6a70c
1 parent 72245c3 commit 3e79ea4
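
For orientation, the new flags below are consumed by the llama export entry point. The sketch that follows is a hypothetical invocation, not part of this commit: the entry-point module path and the checkpoint/params file names are assumptions, while the SpinQuant-related flags and their values come from the argparse definitions in this diff.

import subprocess

# Hypothetical export invocation exercising the new SpinQuant flags.
# The entry-point module and the input file paths are placeholders.
subprocess.run(
    [
        "python", "-m", "examples.models.llama2.export_llama",
        "--checkpoint", "consolidated.00.pth",  # placeholder SpinQuant checkpoint
        "--params", "params.json",              # placeholder model params
        "--use_spin_quant", "native",
        "--spin_qmode", "8da4w",
        "--spin_group_size", "32",
        "--spin_embedding_quantize", "8,0",
    ],
    check=True,
)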


4 files changed, 275 additions and 59 deletions


examples/models/llama2/export_llama_lib.py

Lines changed: 80 additions & 32 deletions
@@ -370,6 +370,28 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Use SpinQuant for better quantization performance. Only support cuda and native.",
     )
 
+    parser.add_argument(
+        "--spin_qmode",
+        type=str,
+        default=None,
+        choices=["8da4w"],
+        help="Quantization mode for SpinQuant. Only support 8da4w right now.",
+    )
+
+    parser.add_argument(
+        "--spin_group_size",
+        type=int,
+        default=32,
+        help="group_size for SpinQuant weight quantization",
+    )
+
+    parser.add_argument(
+        "--spin_embedding_quantize",
+        default="8,0",
+        type=str,
+        help="type of embedding quantization for SpinQuant, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
+    )
+
     parser.add_argument(
         "--output_prune_map",
         default=None,
@@ -466,10 +488,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
             max_seq_len=args.max_seq_length,
             output_prune_map_path=args.output_prune_map,
             metadata_str=args.metadata,
+            dtype_override=dtype_override,
             args=args,
         )
         .set_output_dir(output_dir_path)
-        .to_dtype(dtype_override)
         .source_transform(_get_source_transforms(modelname, dtype_override, args))
     )
 
@@ -691,6 +713,7 @@ def _load_llama_model(
     max_seq_len: int = 128,
     output_prune_map_path: Optional[str] = None,
     metadata_str: Optional[str] = None,
+    dtype_override: Optional[DType] = None,
     args,
 ) -> "LLMEdgeManager":
     """
@@ -720,23 +743,32 @@ def _load_llama_model(
         output_prune_map_path=output_prune_map_path,
         args=args,
     )
-    state_dict = model.state_dict()
-    dtype = state_dict[next(iter(state_dict))].dtype
-    assert dtype in [
-        torch.bfloat16,
-        torch.float16,
-        torch.float32,
-    ], f"Only support bfloat16, fp16 or fp32 got {dtype}"
-    logging.info(f"Loaded model with dtype={dtype}")
-
-    if dtype == torch.bfloat16:
-        dtype = DType.bf16
-    elif dtype == torch.float16:
-        dtype = DType.fp16
-    elif dtype == torch.float32:
-        dtype = DType.fp32
+    if dtype_override:
+        assert isinstance(
+            dtype_override, DType
+        ), "Override dtype needs to be of type <DType>"
+        torch_dtype = dtype_override.to_torch_dtype()
+        logging.info(f"model.to {torch_dtype}")
+        model = model.to(dtype=torch_dtype)
+        dtype = dtype_override
     else:
-        raise ValueError(f"Unsupported dtype {dtype}")
+        state_dict = model.state_dict()
+        dtype = state_dict[next(iter(state_dict))].dtype
+        assert dtype in [
+            torch.bfloat16,
+            torch.float16,
+            torch.float32,
+        ], f"Only support bfloat16, fp16 or fp32 got {dtype}"
+        logging.info(f"Loaded model with dtype={dtype}")
+
+        if dtype == torch.bfloat16:
+            dtype = DType.bf16
+        elif dtype == torch.float16:
+            dtype = DType.fp16
+        elif dtype == torch.float32:
+            dtype = DType.fp32
+        else:
+            raise ValueError(f"Unsupported dtype {dtype}")
 
     return LLMEdgeManager(
         model=model,
@@ -769,21 +801,9 @@ def _get_source_transforms( # noqa
     modelname: str, dtype_override: Optional[DType], args
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     transforms = []
-    if args.quantization_mode:
-        modelname = f"{modelname}_q"
-        if args.use_spin_quant is None:
-            transforms.append(
-                get_quant_weight_transform(args, dtype_override, verbose_export())
-            )
-        # For SpinQuant, the checkpoints are already quantized
-        # aka the weights have corresponding scales value,
-        # So that means, we don't need to apply quantization
-        # transform. However, we will still need to apply
-        # transformations that change the model structure to
-        # match the checkpoint format.
-        # transform_for_spinquant() will apply these transformations
-        # later in model.py file.
-        elif args.use_spin_quant == "cuda":
+
+    if args.use_spin_quant:
+        if args.use_spin_quant == "cuda":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_cuda_for_spin_quant,
             )
@@ -796,7 +816,35 @@ def _get_source_transforms( # noqa
 
             transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)
 
+    if args.quantization_mode:
+        """
+        When this option is selected, it finds all linear layers and transforms
+        them into quantized linear equivalent modules.
+
+        There are cases where the checkpoint is already quantized, for example
+        when use_spin_quant is enabled. In that case, it will do the appropriate
+        transformations based on the given checkpoint first. In those cases,
+        if quantization_mode is enabled, it will quantize any remaining linear
+        ops that are not quantized.
+
+        There are cases where this may be a no-op, namely, if all linears are
+        quantized in the checkpoint.
+        """
+        modelname = f"{modelname}_q"
+        transforms.append(
+            get_quant_weight_transform(args, dtype_override, verbose_export())
+        )
+
     if args.embedding_quantize:
+        """
+        When this option is selected, it finds all embedding layers and transforms
+        them into quantized embedding equivalent modules.
+
+        There are cases where the checkpoint is already quantized, for example
+        when use_spin_quant is enabled. In that case, it will do the appropriate
+        transformations based on the given checkpoint first. In those cases,
+        this will be a no-op.
+        """
         modelname = f"{modelname}_e"
         transforms.append(get_quant_embedding_transform(args))
 
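
As a reading aid for the reordered _get_source_transforms above: every entry in transforms is a Module -> Module callable, so the SpinQuant Hadamard injection now runs before the optional linear and embedding quantization transforms, which only touch layers the checkpoint left unquantized. Below is a minimal sketch of how such a list composes, assuming the callables are applied in order (the applying code, e.g. LLMEdgeManager.source_transform, is not shown in this diff).

from typing import Callable, List

import torch


def apply_source_transforms(
    model: torch.nn.Module,
    transforms: List[Callable[[torch.nn.Module], torch.nn.Module]],
) -> torch.nn.Module:
    # Apply each transform in order; a later transform sees the module tree
    # produced by the earlier ones, so it can skip already-quantized layers.
    for transform in transforms:
        model = transform(model)
    return model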

examples/models/llama2/model.py

Lines changed: 37 additions & 9 deletions
@@ -191,16 +191,16 @@ def __init__(self, **kwargs):
             )
         elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
             print("Using SPIN quantization.")
-            assert hasattr(self.args, "group_size"), "group_size must be specified"
+            assert hasattr(self.args, "spin_qmode"), "spin_qmode must be specified"
             assert hasattr(
-                self.args, "quantization_mode"
-            ), "quantization_mode must be specified"
+                self.args, "spin_group_size"
+            ), "spin_group_size must be specified"
             assert hasattr(
                 self.args, "dtype_override"
             ), "dtype_override must be specified"
             from .source_transformation.spin_quant import (
                 sanitize_checkpoint_from_spinquant,
-                transform_for_spinquant,
+                transform_linear_for_spinquant,
             )
 
             mapping = {
@@ -209,17 +209,45 @@ def __init__(self, **kwargs):
                 "bf16": torch.bfloat16,
             }
 
-            self.model_ = transform_for_spinquant(
+            self.model_ = transform_linear_for_spinquant(
                 self.model_,
                 checkpoint,
-                self.args.group_size,
-                self.args.quantization_mode,
+                self.args.spin_group_size,
+                self.args.spin_qmode,
                 mapping[self.args.dtype_override],
             )
 
+            embedding_bit_width, embedding_group_size = None, None
+            if hasattr(self.args, "spin_embedding_quantize"):
+                embedding_bit_width, embedding_group_size = (
+                    self.args.spin_embedding_quantize.split(",")
+                )
+                from .source_transformation.spin_quant import (
+                    transform_embedding_for_spinquant,
+                )
+
+                if (
+                    embedding_group_size == "none"
+                    or embedding_group_size == "None"
+                    or embedding_group_size == "0"
+                ):
+                    embedding_group_size = None
+                else:
+                    embedding_group_size = int(embedding_group_size)
+
+                self.model_ = transform_embedding_for_spinquant(
+                    self.model_,
+                    checkpoint,
+                    mapping[self.args.dtype_override],
+                    int(embedding_bit_width),
+                    embedding_group_size,
+                )
+
             sanitize_checkpoint_from_spinquant(
-                checkpoint,
-                self.args.group_size,
+                module=self.model_,
+                checkpoint=checkpoint,
+                linear_group_size=self.args.spin_group_size,
+                embedding_group_size=embedding_group_size,
             )
 
             # assign=True: load params/buffers by assignment instead of performing an in-place copy.
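
The spin_embedding_quantize string handled above uses the '<bitwidth>,<groupsize>' format of the new CLI flag, where a group size of "none", "None", or "0" means ungrouped embedding scales. Here is a small standalone sketch of that parsing; the helper name is hypothetical and not part of this diff.

from typing import Optional, Tuple


def parse_spin_embedding_quantize(value: str) -> Tuple[int, Optional[int]]:
    # "<bitwidth>,<groupsize>"; "none"/"None"/"0" means the scales are not grouped.
    bit_width, group_size = value.split(",")
    if group_size in ("none", "None", "0"):
        return int(bit_width), None
    return int(bit_width), int(group_size)


# e.g. the default "8,0" parses to (8, None); "8,1024" parses to (8, 1024)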

examples/models/llama2/source_transformation/spin_quant.py

Lines changed: 85 additions & 4 deletions
@@ -9,7 +9,7 @@
 # Helper functions for tranforming the model to be able to run SpinQuant.
 # See https://github.com/facebookresearch/SpinQuant for more details about SpinQuant.
 
-from typing import Any
+from typing import Any, Optional
 
 import torch
 
@@ -20,6 +20,8 @@
 from torchao.quantization.GPTQ import _check_linear_int4_k, Int8DynActInt4WeightLinear
 from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 
+from .quantize import QuantizedGroupEmbedding
+
 
 def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module):
     """
@@ -123,7 +125,7 @@ def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
     _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
 
 
-def transform_for_spinquant(
+def transform_linear_for_spinquant(
     module: torch.nn.Module,
     checkpoint: Any,
     group_size: int,
@@ -151,9 +153,64 @@ def transform_for_spinquant(
     return module
 
 
+def _replace_embedding_with_quantized_group_embedding_for_spinquant(
+    module: torch.nn.Module,
+    checkpoint: Any,
+    dtype: torch.dtype,
+    bit_width: int,
+    group_size: Optional[int] = None,
+):
+    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
+        # Only replace embedding layers where the checkpoint contains explicit scales
+        scales_key = f"{cur_fqn}.scale"
+        if isinstance(child, nn.Embedding) and scales_key in checkpoint:
+            assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
+            assert checkpoint[scales_key].dtype == torch.float32
+            return True
+        return False
+
+    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
+        new_embedding = QuantizedGroupEmbedding(
+            device=child.weight.device,
+            vocab_size=child.weight.shape[0],
+            embedding_dim=child.weight.shape[1],
+            group_size=group_size,
+            dtype=dtype,
+            packed=False,  # TODO(lunwenh): support packed embedding for SpinQuant
+        )
+        return new_embedding
+
+    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
+
+
+def transform_embedding_for_spinquant(
+    module: torch.nn.Module,
+    checkpoint: Any,
+    dtype: torch.dtype,
+    bit_width: int,
+    group_size: Optional[int] = None,
+) -> torch.nn.Module:
+    """
+    Transform the model to be able to load SpinQuant checkpoints that
+    are quantized with the given bit_width and group size for embedding.
+    """
+    if group_size is not None and group_size not in [0, 32, 64, 128, 256]:
+        raise ValueError(f"Group size {group_size} is not supported for SpinQuant.")
+    _replace_embedding_with_quantized_group_embedding_for_spinquant(
+        module,
+        checkpoint,
+        dtype,
+        bit_width,
+        group_size,
+    )
+    return module
+
+
 def sanitize_checkpoint_from_spinquant(
+    module: torch.nn.Module,
     checkpoint: Any,
-    group_size: int,
+    linear_group_size: int,
+    embedding_group_size: Optional[int] = None,
 ):
     """
     Sanitize the SpinQuant checkpoint.
@@ -173,7 +230,31 @@ def sanitize_checkpoint_from_spinquant(
 
     for old_key, new_key in keys_to_rename:
         old_val = checkpoint.pop(old_key)
-        checkpoint[new_key] = old_val if group_size == -1 else old_val[:, ::group_size]
+        module_name = new_key[0 : new_key.rfind(".")]
+        sub_module = module.get_submodule(module_name)
+        assert sub_module is not None
+        assert isinstance(sub_module, Int8DynActInt4WeightLinear) or isinstance(
+            sub_module, QuantizedGroupEmbedding
+        )
+        # Checkpoints with SpinQuant could come with two formats for scales:
+        # 1. scales is grouped by group size
+        # 2. scales is not grouped by group size
+        # We need to handle both cases here.
+        # TODO(lunwenh): remove this once we have a unified format for scales.
+        if isinstance(sub_module, Int8DynActInt4WeightLinear):
+            checkpoint[new_key] = (
+                old_val if linear_group_size == -1 else old_val[:, ::linear_group_size]
+            )
+        elif isinstance(sub_module, QuantizedGroupEmbedding):
+            if (
+                embedding_group_size is None or embedding_group_size == 0
+            ):  # Scales are not grouped
+                checkpoint[new_key] = old_val[:, 0]
+            elif embedding_group_size == -1:  # Scales are grouped by group size
+                checkpoint[new_key] = old_val
+            else:
+                checkpoint[new_key] = old_val[:, ::embedding_group_size]
+
     for k in keys_to_remove:
         checkpoint.pop(k)
     for k, v in checkpoint.items():
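
To make the scale sanitization above concrete, here is a toy illustration of the two slicings, under the assumption that a scales entry has shape [out_features, in_features] with one value repeated across each group; it is an illustration only, not code from the diff.

import torch

out_features, in_features, group_size = 4, 8, 4
scales = torch.rand(out_features, in_features)  # stand-in for a checkpoint scales entry

grouped = scales[:, ::group_size]  # keep every group_size-th column -> shape (4, 2)
ungrouped = scales[:, 0]           # keep only the first column      -> shape (4,)
print(grouped.shape, ungrouped.shape)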

0 commit comments
