
Commit 785e92a

mikekgfb authored and malfet committed

unified quantizer for 4b (#625)

* unified quantizer for 4b
* typo
* typo
* fix argument spec
* gguf interface

1 parent 7b4f056 · commit 785e92a

File tree

3 files changed: +117, -18 lines

build/gguf_loader.py

Lines changed: 3 additions & 3 deletions
@@ -181,10 +181,10 @@ def load_model_and_state_dict(
             parent,
             _fqn_last(fqn),
             WeightOnlyInt4Linear(
-                "meta",
-                in_features,
-                out_features,
+                in_features=in_features,
+                out_features=out_features,
                 bias=False,
+                device="meta",
                 groupsize=Q4_0.groupsize,
                 inner_k_tiles=inner_k_tiles,
             ),
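
The change above moves device out of the leading positional slot and passes everything by keyword, matching the reworked constructor in qops.py. A minimal sketch of the new call convention; the import alias and the 4096x4096 shape are assumptions for illustration, not code from this commit:

    from qops import LinearInt4 as WeightOnlyInt4Linear  # assumed alias

    # Construct on the meta device: no storage is allocated, which is how
    # the GGUF loader defers materializing the real quantized buffers.
    layer = WeightOnlyInt4Linear(
        in_features=4096,   # hypothetical layer shape
        out_features=4096,
        bias=False,
        device="meta",
        groupsize=32,       # stands in for Q4_0.groupsize
        inner_k_tiles=8,
    )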

qops.py

Lines changed: 21 additions & 10 deletions
@@ -323,13 +323,16 @@ class LinearInt4(torch.nn.Module):
 
     def __init__(
         self,
-        device: str,
         in_features: int,
         out_features: int,
         bias=True,
+        device=None,
         dtype=None,
+        *,
         groupsize: int = 128,
         inner_k_tiles: int = 8,
+        weight: Optional[torch.Tensor] = None,
+        scales_and_zeros: Optional[torch.Tensor] = None,
     ) -> None:
         super().__init__()
         self.padding = not self._check_k(
@@ -351,9 +354,12 @@ def __init__(
         assert (
             in_features % (inner_k_tiles * 16) == 0
         ), "require in_features % (innerKTiles * 16) == 0"
-        self.register_buffer(
-            "weight",
-            torch.empty(
+        assert (weight is None) == bool(
+            scales_and_zeros is None
+        ), "must specify both weights and scales_and_zeros, or neither"
+
+        if weight is None:
+            weight = torch.empty(
                 (
                     out_features // 8,
                     in_features // (inner_k_tiles * 16),
@@ -362,15 +368,20 @@ def __init__(
                 ),
                 dtype=torch.int32,
                 device=device,
-            ),
-        )
-        self.register_buffer(
-            "scales_and_zeros",
-            torch.empty(
+            )
+            scales_and_zeros = torch.empty(
                 (in_features // groupsize, out_features, 2),
                 dtype=get_precision(),
                 device=device,
-            ),
+            )
+
+        self.register_buffer(
+            "weight",
+            weight,
+        )
+        self.register_buffer(
+            "scales_and_zeros",
+            scales_and_zeros,
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
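
Net effect of this diff: LinearInt4 can now be built either with empty placeholder buffers (the old behavior, later filled from a checkpoint) or with prequantized tensors handed in directly, which the unified quantizer in quantize.py relies on. A rough sketch of both paths; the full packed-weight shape is an assumption (the diff elides the middle of the torch.empty call between hunks), and bfloat16 stands in for whatever get_precision() returns:

    import torch

    # Path 1 (old behavior): allocate empty buffers, fill them later via
    # load_state_dict.
    lin = LinearInt4(in_features=4096, out_features=4096, bias=False)

    # Path 2 (new): pass packed int4 weights plus scales/zeros; both must
    # be given together or the assert in __init__ fires.
    groupsize, inner_k_tiles = 128, 8
    packed = torch.empty(
        (4096 // 8, 4096 // (inner_k_tiles * 16), 32, inner_k_tiles // 2),
        dtype=torch.int32,
    )  # assumed full layout; only the first two dims appear in the diff
    sz = torch.empty((4096 // groupsize, 4096, 2), dtype=torch.bfloat16)
    lin = LinearInt4(
        in_features=4096, out_features=4096, bias=False,
        weight=packed, scales_and_zeros=sz,
    )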

quantize.py

Lines changed: 93 additions & 5 deletions
@@ -54,7 +54,7 @@ def quantize_model(model: nn.Module, device, quantize_options, tokenizer=None):
         raise RuntimeError(f"unknown quantizer {quantizer} specified")
 
     model = quantizer_class_dict[quantizer](
-        model, device, tokenizer, **q_kwargs
+        model, device=device, tokenizer=tokenizer, **q_kwargs
     ).quantized_model()
 
 
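Passing device and tokenizer by keyword is what makes this construction path uniform: each handler may order those parameters freely as long as the names match. A hedged sketch of a call site, assuming quantize_options maps quantizer names to their kwargs (the keys mirror quantizer_class_dict below):

    # Hypothetical invocation; the inner dict is forwarded as **q_kwargs.
    quantize_model(
        model,
        device="cuda",
        quantize_options={"linear:int4": {"groupsize": 128, "inner_k_tiles": 8}},
    )
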
@@ -450,7 +450,7 @@ def quantized_model(self) -> nn.Module:
 ##### embedding table quantization ######
 
 
-class EmbeddingOnlyInt8QuantHandler(QuantHandler):
+class EmbeddingOnlyQuantHandler(QuantHandler):
     def __init__(
         self,
         model: nn.Module,
@@ -545,6 +545,94 @@ def quantized_model(self) -> nn.Module:
 ##### weight only int4 per channel groupwise quantized code ######
 
 
+class NewWeightOnlyInt4QuantHandler(QuantHandler):
+    def __init__(
+        self,
+        model: nn.Module,
+        device=None,
+        *,
+        tokenizer=None,
+        groupsize=128,
+        inner_k_tiles=8,
+        padding_allowed=True,
+        weight: Optional[torch.Tensor] = None,
+        scales_and_zeros: Optional[torch.Tensor] = None,
+    ):
+        self.model_ = model
+        self.device = device
+        self.groupsize = groupsize
+        self.inner_k_tiles = inner_k_tiles
+        self.padding_allowed = padding_allowed
+        assert groupsize in [32, 64, 128, 256]
+        assert inner_k_tiles in [2, 4, 8]
+
+    @torch.no_grad()
+    def quantize(self, module):
+        # cur_state_dict = state_dict_device(self.model_.state_dict())
+        # dict_device = "cpu"  # self.device
+
+        device = self.device
+
+        for name, child in module.named_children():
+            # print(f"name: {name}")
+            if isinstance(child, torch.nn.Linear):
+                assert not child.bias
+                out_features = child.out_features
+                in_features = child.in_features
+                assert out_features % 8 == 0, "require out_features % 8 == 0"
+                # print(f"linear: {fqn}, in={in_features}, out={out_features}")
+
+                weight = child.weight.data
+                if not WeightOnlyInt4Linear._check_k(
+                    k=in_features,
+                    groupsize=self.groupsize,
+                    inner_k_tiles=self.inner_k_tiles,
+                ):
+                    if self.padding_allowed:
+                        print(
+                            f"warning: {name} is padded to satisfy in_features % 1024 == 0"
+                        )
+                        padded_in_features = find_multiple(in_features, 1024)
+                        weight = F.pad(
+                            weight, pad=(0, padded_in_features - in_features)
+                        )
+                    else:
+                        print(
+                            f"warning: {name} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
+                            + "and that groupsize and inner_k_tiles*16 evenly divide into it"
+                        )
+                        continue
+                weight_int4pack, scales_and_zeros = (
+                    WeightOnlyInt4Linear._prepare_weight_and_scales_and_zeros(
+                        weight.to(torch.float), self.groupsize, self.inner_k_tiles
+                    )
+                )
+                weight_int4pack = weight_int4pack.to(device=self.device)
+                scales_and_zeros = scales_and_zeros.to(device=self.device)
+
+                setattr(
+                    module,
+                    name,
+                    WeightOnlyInt4Linear(
+                        child.in_features,
+                        child.out_features,
+                        bias=False,
+                        device=self.device,
+                        groupsize=self.groupsize,
+                        inner_k_tiles=self.inner_k_tiles,
+                        weight=weight_int4pack,
+                        scales_and_zeros=scales_and_zeros,
+                    ),
+                )
+            else:
+                self.quantize(child)
+
+        return module
+
+    def quantized_model(self) -> nn.Module:
+        return self.quantize(self.model_)
+
+
 def replace_linear_int4(
     module,
     device,
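
The new handler is self-contained: quantize() walks named_children recursively, packs each eligible nn.Linear weight, and swaps in a WeightOnlyInt4Linear built from the prepacked buffers via the new weight/scales_and_zeros constructor arguments from qops.py. A minimal usage sketch on a toy module (the toy model and device string are assumptions; shapes are chosen to satisfy the _check_k divisibility rules):

    import torch.nn as nn

    toy = nn.Sequential(nn.Linear(4096, 4096, bias=False))  # toy model
    handler = NewWeightOnlyInt4QuantHandler(
        toy, device="cuda", groupsize=128, inner_k_tiles=8
    )
    toy = handler.quantized_model()  # mutates and returns the module tree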
@@ -563,10 +651,10 @@ def replace_linear_int4(
             module,
             name,
             WeightOnlyInt4Linear(
-                device,
                 child.in_features,
                 child.out_features,
                 bias=False,
+                device=device,
                 groupsize=groupsize,
                 inner_k_tiles=inner_k_tiles,
             ),
@@ -1001,9 +1089,9 @@ def quantized_model(self) -> nn.Module:
 # Must come last because __future__ annotations don't work for naked
 # class references
 quantizer_class_dict = {
-    "embedding": EmbeddingOnlyInt8QuantHandler,
+    "embedding": EmbeddingOnlyQuantHandler,
     "linear:int8": WeightOnlyInt8QuantHandler,
-    "linear:int4": WeightOnlyInt4QuantHandler,
+    "linear:int4": NewWeightOnlyInt4QuantHandler,
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
     "linear:int4-gptq": WeightOnlyInt4GPTQQuantHandler,
     "linear:hqq": WeightOnlyInt4HqqQuantHandler,
