Commit 78ce089

andrewor14 authored and facebook-github-bot committed
Support fp32 activation for quantizing llama2 to int8 activation and int4 weight (#2032)
Summary:
Pull Request resolved: #2032

Previously we only supported quantizing fp16 activations to int8. This adds support for quantizing fp32 activations as well, to enable testing.

Representation: https://www.internalfb.com/intern/everpaste/?handle=GAoWXBlf5O8T4TMBAP8Cps4UiVx7bsIXAAAz

Reviewed By: digantdesai

Differential Revision: D54032932

fbshipit-source-id: baaddbd385985240689444041c5de33245c86dcf
1 parent 33306d3 commit 78ce089
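
At a high level, the change lets the int8-activation/int4-weight (8da4w) path keep activations in fp32 instead of forcing fp16. A minimal sketch of how the new argument is exercised end to end; the import paths and the pre-loaded fp32 `model` are assumptions for illustration, not part of this diff:

import torch

# Assumed import paths; the actual package layout may differ.
from examples.models.llama2.builder import DType
from examples.models.llama2.export_llama_lib import quantize


def quantize_8da4w_fp32(model: torch.nn.Module) -> torch.nn.Module:
    # With this commit the activation dtype is threaded through to the int4
    # handler instead of being hard-coded to torch.float16 inside it.
    return quantize(model, qmode="int4", activation_dtype=DType.fp32)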

File tree

3 files changed (+94, -35 lines changed)

examples/models/llama2/builder.py

Lines changed: 13 additions & 7 deletions
@@ -46,6 +46,15 @@ class DType(Enum):
     fp32 = "fp32"
     fp16 = "fp16"
 
+    def to_torch_dtype(self) -> torch.dtype:
+        mapping = {
+            DType.fp32: torch.float32,
+            DType.fp16: torch.float16,
+        }
+        if self not in mapping:
+            raise ValueError(f"Unsupported dtype {self}")
+        return mapping[self]
+
 
 def load_llama_model(
     *,
@@ -145,13 +154,10 @@ def to_dtype(self, dtype_override: Optional[DType]) -> "LlamaEdgeManager":
         assert not dtype_override or isinstance(
             dtype_override, DType
         ), "Override dtype needs to be of type <DType>"
-        if dtype_override == DType.fp16 and self.dtype != DType.fp16:
-            logging.info("model.to torch.float16")
-            self.model = self.model.to(dtype=torch.float16)
-            self.dtype = dtype_override
-        elif dtype_override == DType.fp32 and self.dtype != DType.fp32:
-            logging.info("model.to torch.float32")
-            self.model = self.model.to(dtype=torch.float32)
+        if dtype_override is not None and dtype_override != self.dtype:
+            torch_dtype = dtype_override.to_torch_dtype()
+            logging.info(f"model.to {torch_dtype}")
+            self.model = self.model.to(dtype=torch_dtype)
             self.dtype = dtype_override
         return self

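The refactor above collapses the per-dtype if/elif branches in to_dtype into a single enum-to-torch-dtype lookup followed by one cast. A small illustrative check of the new helper (the import path is an assumption):

import torch

from examples.models.llama2.builder import DType  # assumed import path

# One mapping in one place instead of a branch per dtype at the call site.
assert DType.fp32.to_torch_dtype() == torch.float32
assert DType.fp16.to_torch_dtype() == torch.float16

# Passing None to to_dtype() still leaves the model untouched, and a future
# DType member only needs a new mapping entry rather than another elif branch.
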
examples/models/llama2/export_llama_lib.py

Lines changed: 27 additions & 15 deletions
@@ -11,7 +11,7 @@
 import shlex
 from functools import partial
 from pathlib import Path
-from typing import List
+from typing import List, Optional
 
 import pkg_resources
 import torch
@@ -94,7 +94,11 @@ def check_embedding_byte_registered():
     return quantizers
 
 
-def quantize(model: torch.nn.Module, qmode: str) -> torch.nn.Module:
+def quantize(
+    model: torch.nn.Module,
+    qmode: str,
+    activation_dtype: Optional[DType],
+) -> torch.nn.Module:
     """
     Quantizes a model by converting all weights to int8.
     Args:
@@ -103,14 +107,21 @@ def quantize(model: torch.nn.Module, qmode: str) -> torch.nn.Module:
     Returns:
         A quantized model.
     """
+    if activation_dtype is not None:
+        torch_dtype = activation_dtype.to_torch_dtype()
+    else:
+        torch_dtype = torch.float16
+
     if qmode == "int8":
         model_int8 = WeightOnlyInt8QuantHandler(model)
         model_int8_state_dict = model_int8.create_quantized_state_dict()
         model_int8 = model_int8.convert_for_runtime()
         model_int8.load_state_dict(model_int8_state_dict)
         return model_int8
     elif qmode == "int4":
-        model_int4 = Int8DynActInt4WeightQuantHandler(model)
+        model_int4 = Int8DynActInt4WeightQuantHandler(
+            model, activation_precision=torch_dtype
+        )
         model_int4_state_dict = model_int4.create_quantized_state_dict()
         model_int4 = model_int4.convert_for_runtime()
         print("quantized model:", model_int4)
@@ -269,28 +280,29 @@ def _export_llama(modelname, args) -> str: # noqa: C901
     output_dir_path = canonical_path(args.output_dir, dir=True)
     modelname = "llama2"
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
+
+    # dtype override
+    if args.dtype_override is not None:
+        dtype_override = DType[args.dtype_override]
+    else:
+        dtype_override = DType["fp16"] if args.quantization_mode == "int4" else None
+
     # source transforms
     transforms = []
     if args.quantized_ckpt or args.quantization_mode:
         modelname = f"{modelname}_q"
-        transforms.append(partial(quantize, qmode=args.quantization_mode))
+        transforms.append(
+            partial(
+                quantize, qmode=args.quantization_mode, activation_dtype=dtype_override
+            )
+        )
 
     if args.embedding_quantize:
         modelname = f"{modelname}_e"
         transforms.append(
             lambda model: EmbeddingOnlyInt8QuantHandler(model).convert_for_runtime()
         )
 
-    # dtype override
-    if args.dtype_override:
-        override = (
-            DType["fp16"]
-            if args.quantization_mode == "int4"
-            else DType[args.dtype_override]
-        )
-    else:
-        override = None
-
     # export_to_edge
     quantizers = get_pt2e_quantizers(args)
@@ -323,7 +335,7 @@ def _export_llama(modelname, args) -> str: # noqa: C901
         .set_output_dir(output_dir_path)
         .set_metadata(args.metadata)
         .source_transform(transforms)
-        .to_dtype(override)
+        .to_dtype(dtype_override)
         .export_to_edge(quantizers)
         .to_backend(partitioners)
         .to_executorch()

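The dtype override is now resolved once, before the source transforms are built, so the same value feeds both the quantize transform and the later .to_dtype() call. A standalone sketch of that selection logic; the helper name is hypothetical and only mirrors the inline code above (assumed import path):

from typing import Optional

from examples.models.llama2.builder import DType  # assumed import path


def resolve_dtype_override(
    dtype_override_arg: Optional[str], quantization_mode: Optional[str]
) -> Optional[DType]:
    # An explicit args.dtype_override wins; otherwise int4 quantization still
    # defaults to fp16 activations, and other modes leave the dtype untouched.
    if dtype_override_arg is not None:
        return DType[dtype_override_arg]
    return DType["fp16"] if quantization_mode == "int4" else None


assert resolve_dtype_override("fp32", "int4") == DType.fp32  # fp32 activations now allowed
assert resolve_dtype_override(None, "int4") == DType.fp16    # unchanged default for int4
assert resolve_dtype_override(None, "int8") is None          # no cast requested
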
examples/models/llama2/quantize.py

Lines changed: 54 additions & 13 deletions
@@ -791,7 +791,14 @@ def _calc_padded_size_linear_int4(k, groupsize=1, inner_k_tiles=1):
     return find_multiple(k, groupsize, inner_k_tiles * 16)
 
 
-def replace_linear_8da4w(module, group_size, inner_k_tiles, padding_allowed):
+def replace_linear_8da4w(
+    module,
+    group_size,
+    inner_k_tiles,
+    padding_allowed,
+    activation_precision,
+    weight_precision,
+):
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
             if (
@@ -807,20 +814,37 @@ def replace_linear_8da4w(module, group_size, inner_k_tiles, padding_allowed):
                         bias=False,
                         group_size=group_size,
                         inner_k_tiles=inner_k_tiles,
+                        activation_precision=activation_precision,
+                        weight_precision=weight_precision,
                     ),
                 )
         else:
-            replace_linear_8da4w(child, group_size, inner_k_tiles, padding_allowed)
+            replace_linear_8da4w(
+                child,
+                group_size,
+                inner_k_tiles,
+                padding_allowed,
+                activation_precision,
+                weight_precision,
+            )
 
 
 class Int8DynActInt4WeightQuantHandler:
-    def __init__(self, mod, group_size=128, inner_k_tiles=8, padding_allowed=True):
+    def __init__(
+        self,
+        mod,
+        group_size=128,
+        inner_k_tiles=8,
+        padding_allowed=True,
+        activation_precision=torch.float16,
+        weight_precision=torch.float16,
+    ):
         self.mod = mod
         self.group_size = group_size
         self.inner_k_tiles = inner_k_tiles
         self.padding_allowed = padding_allowed
-        # TODO: make this an argument
-        self.precision = torch.float16
+        self.activation_precision = activation_precision
+        self.weight_precision = weight_precision
         assert group_size in [32, 64, 128, 256]
         assert inner_k_tiles in [2, 4, 8]
 
@@ -861,7 +885,9 @@ def create_quantized_state_dict(self):
                     weight_int4pack,
                     scales_and_zeros,
                 ) = prepare_int4_weight_and_scales_and_zeros(
-                    weight.to(self.precision), self.group_size, self.inner_k_tiles
+                    weight.to(self.weight_precision),
+                    self.group_size,
+                    self.inner_k_tiles,
                 )
                 cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
                 cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")
@@ -870,7 +896,12 @@
 
     def convert_for_runtime(self):
         replace_linear_8da4w(
-            self.mod, self.group_size, self.inner_k_tiles, self.padding_allowed
+            self.mod,
+            self.group_size,
+            self.inner_k_tiles,
+            self.padding_allowed,
+            self.activation_precision,
+            self.weight_precision,
         )
         return self.mod
 
@@ -891,6 +922,8 @@ def __init__(
         dtype=None,
         group_size: int = 128,
         inner_k_tiles: int = 8,
+        activation_precision: torch.dtype = torch.float16,
+        weight_precision: torch.dtype = torch.float16,
     ) -> None:
         super().__init__()
         # always pad if needed since it becomes a noop at runtime if not needed
@@ -903,7 +936,8 @@
         assert not bias, "require bias=False"
         self.group_size = group_size
         self.inner_k_tiles = inner_k_tiles
-        self.precision = torch.float16
+        self.weight_precision = weight_precision
+        self.activation_precision = activation_precision
 
         # assert out_features % 8 == 0, "require out_features % 8 == 0"
         assert (
@@ -917,12 +951,13 @@
         self.register_buffer(
             "scales_and_zeros",
             torch.empty(
-                (in_features // group_size, out_features, 2), dtype=self.precision
+                (in_features // group_size, out_features, 2),
+                dtype=self.weight_precision,
            ),
        )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        input = input.to(self.precision)
+        input = input.to(self.activation_precision)
         input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
 
         (
@@ -937,15 +972,21 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             input, scales, zero_points, quant_min, quant_max, torch.int8
         )
         input = torch.ops.quantized_decomposed.dequantize_per_token(
-            input, scales, zero_points, quant_min, quant_max, torch.int8, self.precision
+            input,
+            scales,
+            zero_points,
+            quant_min,
+            quant_max,
+            torch.int8,
+            self.activation_precision,
         )
 
-        input = input.to(self.precision)
+        input = input.to(self.activation_precision)
         return linear_forward_int4(
             input,
             self.weight,
             self.scales_and_zeros,
             self.out_features,
             self.group_size,
-            self.precision,
+            self.weight_precision,
         )

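The handler and the Int8DynActInt4WeightLinear modules it installs now track two precisions: weight_precision for the packed int4 weights and their scales/zeros, and activation_precision for the per-token int8 quantize/dequantize in forward(). A hedged construction sketch; the import path and the toy module are illustrative, and real usage goes through the quantize() transform in export_llama_lib.py:

import torch
import torch.nn as nn

# Assumed import path; the handler lives in examples/models/llama2/quantize.py.
from examples.models.llama2.quantize import Int8DynActInt4WeightQuantHandler

# Toy stand-in for the llama2 Transformer, sized so in_features is a
# multiple of group_size and of inner_k_tiles * 16.
model = nn.Sequential(nn.Linear(256, 256, bias=False))

handler = Int8DynActInt4WeightQuantHandler(
    model,
    group_size=128,
    inner_k_tiles=8,
    padding_allowed=True,
    activation_precision=torch.float32,  # new: keep activations in fp32
    weight_precision=torch.float16,      # weights/scales still packed in fp16
)
quantized_state_dict = handler.create_quantized_state_dict()
model = handler.convert_for_runtime()
model.load_state_dict(quantized_state_dict)
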