
Commit 77a581f

mikekgfb authored and malfet committed
Gptq hqq (#634)
* unified quantizer for 4b
* typo
* typo
* fix argument spec
* gguf interface
* GPTQ and HQQ code alignment
1 parent f79420d commit 77a581f
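At a glance: the quantize.py hunks below converge the plain int4, GPTQ, and HQQ handlers on a single in-place entry point, where each handler implements quantize(module) and quantized_model() simply delegates to it. A minimal sketch of that shared shape, for orientation only (the class name and the placeholder body are illustrative, not code from this commit):

import torch
import torch.nn as nn


class SketchQuantHandler:  # illustrative name; the real handlers subclass QuantHandler
    def __init__(self, model: nn.Module, device, tokenizer=None):
        self.model_ = model
        self.device = device

    @torch.no_grad()
    def quantize(self, module: nn.Module) -> nn.Module:
        # Walk the module tree and swap eligible layers in place for
        # quantized equivalents; details vary per handler.
        return module

    def quantized_model(self) -> nn.Module:
        # The pattern this commit standardizes on: delegate to quantize().
        return self.quantize(self.model_)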

File tree

3 files changed: +47 -165 lines changed

  qops.py
  quantize.py
  scripts/process-readme.py

qops.py

Lines changed: 1 addition & 8 deletions
@@ -4,14 +4,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from build.utils import (
-    find_multiple,
-    get_device_str,
-    get_precision,
-    name_to_dtype,
-    state_dict_device,
-    use_et_backend,
-)
+from build.utils import find_multiple, get_precision, use_et_backend
 
 # from torch.nn.parameter import Parameter

quantize.py

Lines changed: 34 additions & 157 deletions
@@ -21,7 +21,6 @@
     get_precision,
     name_to_dtype,
     state_dict_device,
-    use_et_backend,
 )
 
 from qops import (
@@ -389,8 +388,6 @@ def quantize(self, module):
         # cur_state_dict = state_dict_device(self.model_.state_dict())
         # dict_device = "cpu"  # self.device
 
-        device = self.device
-
         if self.bitwidth == 4:
             range_min = -8
             range_max = 7
@@ -468,11 +465,6 @@ def __init__(
 
     @torch.no_grad()
     def quantize(self, module):
-        # cur_state_dict = state_dict_device(self.model_.state_dict())
-        # dict_device = "cpu"  # self.device
-
-        device = self.device
-
         if self.bitwidth == 4:
             range_min = -8
             range_max = 7
@@ -544,8 +536,7 @@ def quantized_model(self) -> nn.Module:
 #########################################################################
 ##### weight only int4 per channel groupwise quantized code ######
 
-
-class NewWeightOnlyInt4QuantHandler(QuantHandler):
+class WeightOnlyInt4QuantHandler(QuantHandler):
     def __init__(
         self,
         model: nn.Module,
@@ -568,11 +559,6 @@ def __init__(
 
     @torch.no_grad()
     def quantize(self, module):
-        # cur_state_dict = state_dict_device(self.model_.state_dict())
-        # dict_device = "cpu"  # self.device
-
-        device = self.device
-
         for name, child in module.named_children():
             # print(f"name: {name}")
             if isinstance(child, torch.nn.Linear):
@@ -633,129 +619,6 @@ def quantized_model(self) -> nn.Module:
         return self.quantize(self.model_)
 
 
-def replace_linear_int4(
-    module,
-    device,
-    groupsize,
-    inner_k_tiles,
-    padding_allowed,
-):
-    for name, child in module.named_children():
-        if isinstance(child, nn.Linear):
-            if padding_allowed or WeightOnlyInt4Linear._check_k(
-                k=child.in_features,
-                groupsize=groupsize,
-                inner_k_tiles=inner_k_tiles,
-            ):
-                setattr(
-                    module,
-                    name,
-                    WeightOnlyInt4Linear(
-                        child.in_features,
-                        child.out_features,
-                        bias=False,
-                        device=device,
-                        groupsize=groupsize,
-                        inner_k_tiles=inner_k_tiles,
-                    ),
-                )
-        else:
-            replace_linear_int4(
-                child, device, groupsize, inner_k_tiles, padding_allowed
-            )
-
-
-class WeightOnlyInt4QuantHandler(QuantHandler):
-    def __init__(
-        self,
-        model: nn.Module,
-        device,
-        tokenizer=None,
-        *,
-        groupsize=128,
-        inner_k_tiles=8,
-        padding_allowed=True,
-    ):
-        self.model_ = model
-        self.device = device
-        self.groupsize = groupsize
-        self.inner_k_tiles = inner_k_tiles
-        self.padding_allowed = padding_allowed
-        assert groupsize in [32, 64, 128, 256]
-        assert inner_k_tiles in [2, 4, 8]
-
-    # @torch.no_grad()
-    # def p(self):
-    #     cur_state_dict = state_dict_device(self.model_.state_dict())
-    #     dict_device = "cpu"  # self.device
-    #
-    #     for fqn, mod in self.model_.named_modules():
-    #         if hasattr(mod, "weight"):
-    #             print(f"device={str(mod.weight.data.device)}")
-
-    @torch.no_grad()
-    def create_quantized_state_dict(self):
-        cur_state_dict = state_dict_device(self.model_.state_dict())
-        dict_device = "cpu"  # self.device
-
-        for fqn, mod in self.model_.named_modules():
-            if isinstance(mod, torch.nn.Linear):
-                assert not mod.bias
-                out_features = mod.out_features
-                in_features = mod.in_features
-                assert out_features % 8 == 0, "require out_features % 8 == 0"
-                # print(f"linear: {fqn}, in={in_features}, out={out_features}")
-
-                weight = mod.weight.data
-                if not WeightOnlyInt4Linear._check_k(
-                    k=in_features,
-                    groupsize=self.groupsize,
-                    inner_k_tiles=self.inner_k_tiles,
-                ):
-                    if self.padding_allowed:
-                        print(
-                            f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
-                        )
-                        padded_in_features = find_multiple(in_features, 1024)
-                        weight = F.pad(
-                            weight, pad=(0, padded_in_features - in_features)
-                        )
-                    else:
-                        print(
-                            f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
-                            + "and that groupsize and inner_k_tiles*16 evenly divide into it"
-                        )
-                        continue
-                weight_int4pack, scales_and_zeros = (
-                    WeightOnlyInt4Linear._prepare_weight_and_scales_and_zeros(
-                        weight.to(torch.float), self.groupsize, self.inner_k_tiles
-                    )
-                )
-                weight_int4pack = weight_int4pack.to(device=dict_device)
-                scales_and_zeros = scales_and_zeros.to(device=dict_device)
-                cur_state_dict[f"{fqn}.weight"] = weight_int4pack
-                cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros
-
-        return cur_state_dict
-
-    def convert_for_runtime(self):
-        replace_linear_int4(
-            self.model_,
-            self.device,
-            self.groupsize,
-            self.inner_k_tiles,
-            self.padding_allowed,
-        )
-        return self.model_
-
-    def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict()
-        self.convert_for_runtime()
-        self.model_.load_state_dict(model_updated_state_dict)
-        # self.p()
-        return self.model_
-
-
 #########################################################################
 ##### GPTQ #####
 
@@ -1011,13 +874,35 @@ def make_names_and_values_dict_func(q, qparams):
         self.make_names_and_values_dict_func = make_names_and_values_dict_func
         super().__init__()
 
+    def replace_linear_int4(
+        self,
+        module,
+    ):
+        for name, child in module.named_children():
+            if isinstance(child, nn.Linear):
+                if self.padding_allowed or WeightOnlyInt4Linear._check_k(
+                    k=child.in_features,
+                    groupsize=self.groupsize,
+                    inner_k_tiles=self.inner_k_tiles,
+                ):
+                    setattr(
+                        module,
+                        name,
+                        WeightOnlyInt4Linear(
+                            child.in_features,
+                            child.out_features,
+                            bias=False,
+                            device=self.device,
+                            groupsize=self.groupsize,
+                            inner_k_tiles=self.inner_k_tiles,
+                        ),
+                    )
+            else:
+                self.replace_linear_int4(child)
+
     def convert_for_runtime(self):
-        replace_linear_int4(
+        self.replace_linear_int4(
             self.model_,
-            self.device,
-            self.groupsize,
-            self.inner_k_tiles,
-            self.padding_allowed,
         )
         return self.model_
 
@@ -1048,7 +933,8 @@ def __init__(self, model: nn.Module, device, tokenizer=None, *, groupsize):
         self.device = device
         self.groupsize = groupsize
 
-    def create_quantized_state_dict(self):
+    @torch.no_grad()
+    def quantize(self, module):
         from hqq.core.quantize import Quantizer
 
         for m in self.model_.modules():
@@ -1066,20 +952,11 @@ def create_quantized_state_dict(self):
                 )
 
         return WeightOnlyInt4QuantHandler(
-            self.model_, self.device, groupsize=self.groupsize
-        ).create_quantized_state_dict()
-
-    def convert_for_runtime(self):
-        # ALSO: all code must work for CPU, CUDA, MPS
-        return WeightOnlyInt4GPTQQuantHandler(
-            self.model_, self.device, tokenizer=None, groupsize=self.groupsize
-        ).convert_for_runtime()
+            model=self.model_, device=self.device, groupsize=self.groupsize
+        ).quantize(self.model_)
 
     def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict()
-        self.convert_for_runtime()
-        self.model_.load_state_dict(model_updated_state_dict)
-        return self.model_
+        return self.quantize(self.model_)
 
 
 ##########################################################################
@@ -1091,7 +968,7 @@ def quantized_model(self) -> nn.Module:
 quantizer_class_dict = {
     "embedding": EmbeddingOnlyQuantHandler,
     "linear:int8": WeightOnlyInt8QuantHandler,
-    "linear:int4": NewWeightOnlyInt4QuantHandler,
+    "linear:int4": WeightOnlyInt4QuantHandler,
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
     "linear:int4-gptq": WeightOnlyInt4GPTQQuantHandler,
     "linear:hqq": WeightOnlyInt4HqqQuantHandler,

scripts/process-readme.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+def print_between_triple_backticks(filename):
+    with open(filename, "r") as file:
+        lines = file.readlines()
+        print_flag = False
+        for line in lines:
+            if line.startswith("```"):
+                print_flag = not print_flag
+            elif print_flag:
+                print(line, end="")
+
+
+print_between_triple_backticks("README.md")
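The new scripts/process-readme.py helper flips a flag on every triple-backtick fence and prints only the lines inside fenced blocks, so the commands embedded in README.md can be extracted and run elsewhere. A small usage sketch, assuming the function is in scope and a README.md is present; capturing stdout this way is an assumption about how it might be driven, not something the commit does:

import io
from contextlib import redirect_stdout

buf = io.StringIO()
with redirect_stdout(buf):
    print_between_triple_backticks("README.md")
readme_snippets = buf.getvalue()  # the fenced contents, ready for a shell runner or test harness
print(readme_snippets)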
