Commit f85339a

Only migrate 8a4w
1 parent 84981d9 commit f85339a

1 file changed: quantization/quantize.py (+100 additions, -15 deletions)

@@ -15,17 +15,23 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from build.utils import get_device_str, get_precision, name_to_dtype, state_dict_device
+from build.utils import (
+    find_multiple,
+    get_device_str,
+    get_precision,
+    name_to_dtype,
+    state_dict_device,
+)
+
+from quantization.qops import (
+    LinearInt4 as WeightOnlyInt4Linear,
+    LinearInt8 as WeightOnlyInt8Linear,
+    QuantizedEmbedding,
+)
 
-from quantization.qops import LinearInt8 as WeightOnlyInt8Linear, QuantizedEmbedding
 # AttributeError: '_OpNamespace' 'quantized_decomposed' object has no attribute 'quantize_per_channel_group'
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa
-from torchao.quantization.quant_api import (
-    quantize_,
-    int4_weight_only,
-    Int4WeightOnlyQuantizer,
-    Int8DynActInt4WeightQuantizer,
-)
+from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
 
 
 #########################################################################
@@ -60,12 +66,6 @@ def quantize_model(model: nn.Module, device, quantize_options, tokenizer=None):
         else:
             precision = get_precision()
 
-        # Only use quant API for dtype bf16 and CUDA
-        if quantizer == "linear:int4" and precision == torch.bfloat16 and device == "cuda":
-            quantize_(model, int4_weight_only(group_size=q_kwargs["groupsize"]))
-            model.to(device="cuda")
-            continue
-
         try:
             # Easier to ask forgiveness than permission
             quant_handler = ao_quantizer_class_dict[quantizer](
@@ -540,6 +540,91 @@ def quantized_model(self) -> nn.Module:
         return self.quantize(self.model_)
 
 
+#########################################################################
+##### weight only int4 per channel groupwise quantized code ######
+
+
+class WeightOnlyInt4QuantHandler(QuantHandler):
+    def __init__(
+        self,
+        model: nn.Module,
+        device=None,
+        *,
+        tokenizer=None,
+        groupsize=128,
+        inner_k_tiles=8,
+        padding_allowed=True,
+    ):
+        self.model_ = model
+        self.device = device
+        self.groupsize = groupsize
+        self.inner_k_tiles = inner_k_tiles
+        self.padding_allowed = padding_allowed
+        assert groupsize in [32, 64, 128, 256]
+        assert inner_k_tiles in [2, 4, 8]
+
+    @torch.no_grad()
+    def quantize(self, module):
+        for name, child in module.named_children():
+            # print(f"name: {name}")
+            if isinstance(child, torch.nn.Linear):
+                assert not child.bias
+                out_features = child.out_features
+                in_features = child.in_features
+                assert out_features % 8 == 0, "require out_features % 8 == 0"
+                # print(f"linear: {fqn}, in={in_features}, out={out_features}")
+
+                weight = child.weight.data
+                if not WeightOnlyInt4Linear._check_k(
+                    k=in_features,
+                    groupsize=self.groupsize,
+                    inner_k_tiles=self.inner_k_tiles,
+                ):
+                    if self.padding_allowed:
+                        # print(
+                        #     f"warning: {name} is padded to satisfy in_features % 1024 == 0"
+                        # )
+                        padded_in_features = find_multiple(in_features, 1024)
+                        weight = F.pad(
+                            weight, pad=(0, padded_in_features - in_features)
+                        )
+                    else:
+                        print(
+                            f"warning: {name} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
+                            + "and that groupsize and inner_k_tiles*16 evenly divide into it"
+                        )
+                        continue
+                weight_int4pack, scales_and_zeros = (
+                    WeightOnlyInt4Linear._prepare_weight_and_scales_and_zeros(
+                        weight.to(torch.float), self.groupsize, self.inner_k_tiles
+                    )
+                )
+                weight_int4pack = weight_int4pack.to(device=self.device)
+                scales_and_zeros = scales_and_zeros.to(device=self.device)
+
+                setattr(
+                    module,
+                    name,
+                    WeightOnlyInt4Linear(
+                        child.in_features,
+                        child.out_features,
+                        bias=False,
+                        device=self.device,
+                        groupsize=self.groupsize,
+                        inner_k_tiles=self.inner_k_tiles,
+                        weight=weight_int4pack,
+                        scales_and_zeros=scales_and_zeros,
+                    ),
+                )
+            else:
+                self.quantize(child)
+
+        return module
+
+    def quantized_model(self) -> nn.Module:
+        return self.quantize(self.model_)
+
+
 ##########################################################################
 ### quantization dictionary ###
 
@@ -549,11 +634,11 @@ def quantized_model(self) -> nn.Module:
 quantizer_class_dict = {
     "embedding": EmbeddingOnlyQuantHandler,
     "linear:int8": WeightOnlyInt8QuantHandler,
+    "linear:int4": WeightOnlyInt4QuantHandler,
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
 }
 
 ao_quantizer_class_dict = {
-    "linear:int4": Int4WeightOnlyQuantizer,
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
 }
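
For reference, a minimal usage sketch of the handler added above. It relies only on the constructor and quantized_model() shown in this diff; the toy model, device string, and import path are illustrative assumptions, and the packed int4 kernels behind WeightOnlyInt4Linear may require specific hardware and dtype support.

    import torch.nn as nn

    from quantization.quantize import WeightOnlyInt4QuantHandler  # class added in this commit

    # Toy stand-in for a transformer: a bias-free Linear whose shapes satisfy
    # the handler's asserts (out_features % 8 == 0) and the _check_k requirement.
    model = nn.Sequential(nn.Linear(4096, 4096, bias=False))

    handler = WeightOnlyInt4QuantHandler(
        model,
        device="cuda",    # assumed; whatever device the packed kernels support
        groupsize=128,    # must be one of 32, 64, 128, 256
        inner_k_tiles=8,  # must be one of 2, 4, 8
    )
    model = handler.quantized_model()  # swaps eligible nn.Linear modules for WeightOnlyInt4Linear

Because quantize() recurses over named_children() and calls setattr on the parent module, the returned object is the same model instance with its Linear layers replaced in place.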

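The dictionary change in the last hunk moves "linear:int4" from the torchao-backed table to the eager, module-swapping table, leaving only "linear:a8w4dq" in ao_quantizer_class_dict. Below is a rough sketch of how the two tables could be consulted, mirroring the EAFP pattern visible in the second hunk; the option format and constructor arguments are assumptions for illustration, not the actual quantize_model body.

    # Hypothetical dispatch sketch; parameter names mirror the quantize_model
    # signature in the hunk header above, not the full implementation.
    def dispatch(model, device, quantize_options, tokenizer=None):
        for quantizer, q_kwargs in quantize_options.items():
            try:
                # Easier to ask forgiveness than permission: torchao-backed
                # quantizers; after this commit only "linear:a8w4dq" lives here.
                ao_cls = ao_quantizer_class_dict[quantizer]
            except KeyError:
                # Eager handlers defined in this file; "linear:int4" now lands
                # here next to "embedding", "linear:int8", "precision", "executor".
                model = quantizer_class_dict[quantizer](
                    model, device, tokenizer=tokenizer, **q_kwargs
                ).quantized_model()
            else:
                model = ao_cls(**q_kwargs).quantize(model)
        return model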