
Commit 4d2f76c

mikekgfb authored and malfet committed
fix cmake version, padding for a8w4dq, lint... (#680)
* fix cmake version, padding for a8w4dq, lint...
* fix cmake version, padding for a8w4dq, lint...
* updates
* fix
1 parent 66ed732 commit 4d2f76c

File tree

11 files changed: +259 -37 lines changed


.github/workflows/run-readme-pr-macos.yml

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-name: Run the README instructions - with stories - to ensure they work
+name: Run the README instructions - with stories - on MacOS
 on:
   pull_request:
   push:

.github/workflows/run-readme-pr.yml

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-name: Run the README instructions - with stories - to ensure they work
+name: Run the README instructions - with stories

 on:
   pull_request:

README.md

Lines changed: 1 addition & 1 deletion
@@ -199,7 +199,7 @@ export TORCHCHAT_ROOT=${PWD}
 ### Export for mobile
 The following example uses the Llama3 8B Instruct model.

-[shell default]: echo '{"embedding": {"bitwidth": 4, "groupsize" : 32}, "linear:a8w4dq": {"groupsize" : 32}}' >./config/data/mobile.json
+[#shell default]: echo '{"embedding": {"bitwidth": 4, "groupsize" : 32}, "linear:a8w4dq": {"groupsize" : 32}}' >./config/data/mobile.json

 ```
 # Export
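
Note: the `[shell default]` directive above is what wrote the quantization recipe consumed by the export step; this commit comments it out. For reference, here is a minimal Python sketch that produces the same `config/data/mobile.json` contents. The JSON keys and values are copied verbatim from the README line; the script itself is illustrative and not part of torchchat.

```
# Sketch: write the same quantization recipe the README's shell one-liner produced.
import json
from pathlib import Path

mobile_recipe = {
    "embedding": {"bitwidth": 4, "groupsize": 32},  # 4-bit groupwise embedding quantization
    "linear:a8w4dq": {"groupsize": 32},             # int8 dynamic activations, int4 weights
}

path = Path("config/data/mobile.json")
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(mobile_recipe, indent=2))
```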

build/utils.py

Lines changed: 2 additions & 0 deletions
@@ -10,11 +10,13 @@
 import os
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Tuple
+
 import torch

 ##########################################################################
 ### unpack packed weights ###

+
 def unpack_packed_weights(
     packed_weights: Dict[str, Any],
     packed_linear: Callable,

cli.py

Lines changed: 9 additions & 1 deletion
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import json
+import logging
 import os
 from pathlib import Path

@@ -13,6 +14,13 @@
 from build.utils import allowable_dtype_names, allowable_params_table, get_device_str
 from download import download_and_convert, is_model_downloaded

+FORMAT = (
+    "%(levelname)s: %(asctime)-15s: %(filename)s: %(funcName)s: %(module)s: %(message)s"
+)
+logging.basicConfig(filename="/tmp/torchchat.log", level=logging.INFO, format=FORMAT)
+logger = logging.getLogger(__name__)
+
+
 default_device = "fast"
 default_model_dir = Path(
     os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache")
@@ -316,7 +324,7 @@ def arg_init(args):
     if args.output_pte_path:
         if args.device not in ["cpu", "fast"]:
             raise RuntimeError("Device not supported by ExecuTorch")
-        args.device="cpu"
+        args.device = "cpu"
     else:
         args.device = get_device_str(
             args.quantize.get("executor", {}).get("accelerator", args.device)
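
Note: `basicConfig` above routes records to `/tmp/torchchat.log` at `INFO` level, so other entry points can import the shared `logger` instead of configuring their own; generate.py below switches to exactly that import. A minimal usage sketch, with made-up log messages:

```
# Sketch: consuming the logger configured in cli.py. The import mirrors the
# generate.py change in this commit; the messages are illustrative only.
from cli import logger

logger.info("loading model")                           # written to /tmp/torchchat.log
logger.warning("ExecuTorch export forces device=cpu")  # WARNING and above are captured too
```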

docs/quantization.md

Lines changed: 1 addition & 1 deletion
@@ -11,9 +11,9 @@ While quantization can potentially degrade the model's performance, the methods
 | compression | FP Precision | bitwidth| group size | dynamic activation quantization | Eager | AOTI | ExecuTorch |
 |--|--|--|--|--|--|--|--|
 | linear (asymmetric) | fp32, fp16, bf16 | [8, 4]* | [32, 64, 128, 256]** | ||| 🚧 |
-| linear with dynamic activations (symmetric) | fp32^ | | [32, 64, 128, 256]** | a8w4dq | 🚧 |🚧 ||
 | linear with GPTQ*** (asymmetric) | | |[32, 64, 128, 256]** | ||||
 | linear with HQQ*** (asymmetric) | | |[32, 64, 128, 256]** | ||||
+| linear with dynamic activations (symmetric) | fp32^ | | [32, 64, 128, 256] | a8w4dq | 🚧 |🚧 ||

 ### Embedding Quantization
 Due to the larger vocabulary size of llama3, we also recommend quantizing the embeddings to further reduce the model size for on-device usecases.
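
Note: for readers of the table, here is a simplified sketch of what 'symmetric, groupwise, 4-bit' means for one group of weights. This is not torchao's `group_quantize_tensor_symmetric`; only the signed int4 range matches the `quant_min`/`quant_max` used in qops.py later in this commit, and the group is shrunk for readability (the table allows group sizes 32, 64, 128, 256).

```
# Sketch: simplified symmetric groupwise int4 quantization of one weight group.
import torch

w = torch.tensor([0.12, -0.30, 0.07, 0.25])  # one (tiny) group of fp32 weights
quant_min, quant_max = -8, 7                 # signed int4 range

scale = w.abs().max() / quant_max            # symmetric: zero-point is 0
q = torch.clamp(torch.round(w / scale), quant_min, quant_max).to(torch.int8)
w_hat = q.float() * scale                    # dequantized approximation used at matmul time
print(q.tolist(), (w - w_hat).abs().max().item())
```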

generate.py

Lines changed: 1 addition & 4 deletions
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 import argparse
 import itertools
-
 import logging
 import sys
 import time
@@ -25,9 +24,7 @@
 )
 from build.model import Transformer
 from build.utils import device_sync, set_precision
-from cli import add_arguments_for_generate, arg_init, check_args
-
-logger = logging.getLogger(__name__)
+from cli import add_arguments_for_generate, arg_init, check_args, logger

 B_INST, E_INST = "[INST]", "[/INST]"
 B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"

qops.py

Lines changed: 131 additions & 0 deletions
@@ -405,3 +405,134 @@ def _prepare_weight_and_scales_and_zeros(
     @classmethod
     def _calc_padded_size(cls, *, k, groupsize=1, innner_k_tiles=1):
         return find_multiple(k, 1024)
+
+
+def linear_8da4w(
+    input,
+    weight_int8,
+    scales,
+    zeros,
+    out_features,
+    groupsize,
+    precision,
+):
+    from torchao.quantization.quant_primitives import per_token_dynamic_quant
+
+    input = per_token_dynamic_quant(input)
+    # TODO: verify and remove following reshape code
+    # origin_input_size = input.size()
+    # input = input.reshape(-1, origin_input_size[-1])
+
+    # TODO: better API
+    # weight_int8 = torch.ops.quantized_decomposed.unpack_int4_to_int8(weight_int4packed)
+    n_bit = 4
+    quant_min = -(2 ** (n_bit - 1))
+    quant_max = 2 ** (n_bit - 1) - 1
+    w_dq = torch.ops.quantized_decomposed.dequantize_per_channel_group(
+        weight_int8,
+        scales,
+        zeros,
+        quant_min,
+        quant_max,
+        torch.int8,
+        groupsize,
+        precision,
+    )
+
+    # input = input.to(torch.float16)
+    # w_dq = w_dq.to(torch.float16)
+    c = torch.nn.functional.linear(input, w_dq)
+
+    # new_shape = origin_input_size[:-1] + (out_features,)
+    # c = c.reshape(new_shape)
+
+    return c
+
+
+class LinearAct8Int4DQ(torch.nn.Module):
+    __constants__ = ["in_features", "origin_in_feature", "out_features"]
+    in_features: int
+    origin_in_features: int
+    out_features: int
+    weight: torch.Tensor
+    scales: torch.Tensor
+    zeros: torch.Tensor
+
+    """
+    This module implements a dynamic quantized linear layer with
+    int4 weight. Weights are per channel groupwise
+    quantized. Parameters of importance groupsize: the number of
+    elements in each quantized group precision: precision of input and
+    output. e.g. torch.float32 means input activation is float32 and
+    output is float32. scales_precision: precision of per group
+    scale. """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias=True,
+        device=None,
+        dtype=None,
+        *,
+        groupsize: int = 256,
+        weight: Optional[torch.Tensor] = None,
+        scales: Optional[torch.Tensor] = None,
+        precision: torch.dtype = torch.float32,
+        scales_precision: torch.dtype = torch.float32,
+    ) -> None:
+        super().__init__()
+        # always pad if needed since it becomes a noop at runtime if not needed
+        # self.origin_in_features = in_features
+        self.origin_in_features = in_features
+        in_features = find_multiple(in_features, groupsize)
+        self.in_features = in_features
+        self.in_features = in_features
+        self.out_features = out_features
+        assert not bias, "require bias=False"
+
+        self.groupsize = groupsize
+        # Precision of the activation which also indicates
+        # output precision of the dynamically quantized linear layer
+        # that his module represents.
+        self.precision = precision
+
+        assert (weight is None) == bool(
+            scales is None
+        ), "must specify both weights and scales_and_zeros, or neither"
+
+        if weight is None:
+            weight = torch.empty((out_features, in_features), dtype=torch.int8)
+            scales = torch.empty(
+                (out_features, in_features // groupsize),
+                dtype=scales_precision,
+            )
+
+        # we received an unpadded weight, so pad it
+        if weight.shape[1] != in_features:
+            weight = F.pad(weight, pad=(0, self.in_features - self.origin_in_features))
+
+        # currently storing unpacked int8 weights
+        self.register_buffer("weight", weight)
+        self.register_buffer("scales", scales)
+        self.register_buffer(
+            "zeros",
+            torch.empty(
+                (out_features, in_features // groupsize),
+                dtype=scales_precision,
+            ),
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
+        # This operator does not support anything but FP32, so we do the deed
+        # Eventually push that into linear_8da4w
+        return linear_8da4w(
+            input.float(),
+            self.weight,
+            self.scales,
+            self.zeros,
+            self.out_features,
+            self.groupsize,
+            self.precision,
+        ).to(dtype=input.dtype)
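
Note: the padding behavior this commit fixes is visible directly on the new module: when `in_features` is not a multiple of `groupsize`, the constructor rounds it up with `find_multiple` and `forward` pads the activations to match. Below is a sketch with toy shapes; the weight and scale values are assumptions, the `zeros` buffer is left uninitialized by the module so the numeric output is not meaningful, and running `forward` requires a torchao install that provides `per_token_dynamic_quant` and the decomposed dequantize op.

```
# Sketch: exercising the padding path of LinearAct8Int4DQ with toy shapes.
import torch
from qops import LinearAct8Int4DQ

w_int8 = torch.randint(-8, 8, (16, 64), dtype=torch.int8)  # int4-range values, already padded to 64
scales = torch.full((16, 64 // 32), 0.05)                  # one scale per 32-element group

layer = LinearAct8Int4DQ(40, 16, bias=False, groupsize=32, weight=w_int8, scales=scales)
print(layer.origin_in_features, layer.in_features)         # 40 64 (rounded up to a multiple of 32)

x = torch.randn(2, 40)   # unpadded activations
y = layer(x)             # forward pads x to (2, 64) before calling linear_8da4w
print(y.shape)           # torch.Size([2, 16]); values are placeholders here
```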

quantize.py

Lines changed: 109 additions & 25 deletions
@@ -24,6 +24,7 @@
 )

 from qops import (
+    LinearAct8Int4DQ,
     LinearInt4 as WeightOnlyInt4Linear,
     LinearInt8 as WeightOnlyInt8Linear,
     QuantizedEmbedding,
@@ -83,29 +84,29 @@ def quantized_model(self) -> nn.Module:

 #########################################################################
 ### QuantHandler wrapper for a8w4dq from torchao ###
-
-
-class Int8DynActInt4WeightQuantizer(QuantHandler):
-    def __init__(self, model: nn.Module, device="cpu", tokenizer=None, **kwargs):
-        import torchao.quantization.quant_api as quant_api
-
-        self.model_ = model
-        self.device = device
-        self.tokenizer = tokenizer
-        self.quantizer = quant_api.Int8DynActInt4WeightQuantizer(
-            **kwargs, precision=get_precision(), scales_precision=get_precision()
-        )
-
-    def create_quantized_state_dict(self) -> Dict:  # "StateDict"
-        pass
-
-    def convert_for_runtime(self) -> nn.Module:
-        pass
-
-    def quantized_model(self) -> nn.Module:
-        return self.quantizer.quantize(self.model_)
-
-
+#
+#
+# class Int8DynActInt4WeightQuantizer(QuantHandler):
+#     def __init__(self, model: nn.Module, device="cpu", tokenizer=None, **kwargs):
+#         import torchao.quantization.quant_api as quant_api
+#
+#         self.model_ = model
+#         self.device = device
+#         self.tokenizer = tokenizer
+#         self.quantizer = quant_api.Int8DynActInt4WeightQuantizer(
+#             **kwargs, precision=get_precision(), scales_precision=get_precision()
+#         )
+#
+#     def create_quantized_state_dict(self) -> Dict:  # "StateDict"
+#         pass
+#
+#     def convert_for_runtime(self) -> nn.Module:
+#         pass
+#
+#     def quantized_model(self) -> nn.Module:
+#         return self.quantizer.quantize(self.model_)
+#
+#
 #########################################################################
 ### wrapper for setting precision as a QuantHandler ###

@@ -547,8 +548,6 @@ def __init__(
         groupsize=128,
         inner_k_tiles=8,
         padding_allowed=True,
-        weight: Optional[torch.Tensor] = None,
-        scales_and_zeros: Optional[torch.Tensor] = None,
     ):
         self.model_ = model
         self.device = device
@@ -620,6 +619,91 @@ def quantized_model(self) -> nn.Module:
         return self.quantize(self.model_)


+#########################################################################
+##### weight only int4 per channel groupwise quantized code ######
+
+
+class Int8DynActInt4WeightQuantizer(QuantHandler):
+    def __init__(
+        self,
+        model: nn.Module,
+        device=None,
+        dtype=None,
+        *,
+        tokenizer=None,
+        groupsize=128,
+        padding_allowed=True,
+        precision=torch.float32,
+        scales_precision=torch.float32,
+    ):
+        if dtype is None:
+            dtype = torch.float32
+
+        self.model_ = model
+        self.device = device
+        self.dtype = dtype
+
+        self.groupsize = groupsize
+        self.padding_allowed = padding_allowed
+        self.precision = precision
+        self.scales_precision = scales_precision
+        assert groupsize in [32, 64, 128, 256]
+
+    @torch.no_grad()
+    def quantize(self, module):
+        from torchao.quantization.quant_primitives import (
+            group_quantize_tensor_symmetric,
+        )
+
+        for name, child in module.named_children():
+            # print(f"name: {name}")
+            if isinstance(child, torch.nn.Linear):
+                out_features = child.out_features
+                in_features = child.in_features
+                weight = child.weight.data
+                assert not child.bias
+                assert out_features % 8 == 0, "require out_features % 8 == 0"
+                # print(f"linear: {fqn}, in={in_features}, out={out_features}")
+
+                # if self.padding_allowed:
+                #     padding_multiple=max(self.groupsize, 1024)
+                padding_multiple = self.groupsize
+                padded_in_features = find_multiple(in_features, padding_multiple)
+                weight = F.pad(weight, pad=(0, padded_in_features - in_features))
+                (
+                    weight_int8,
+                    scales,
+                    zeros,
+                ) = group_quantize_tensor_symmetric(
+                    weight.float(),
+                    4,  # n_bit
+                    self.groupsize,
+                    self.scales_precision,
+                )
+
+                setattr(
+                    module,
+                    name,
+                    LinearAct8Int4DQ(
+                        child.in_features,
+                        child.out_features,
+                        bias=False,
+                        device=self.device,
+                        dtype=self.dtype,
+                        groupsize=self.groupsize,
+                        weight=weight_int8.to(device=self.device),
+                        scales=scales.to(device=self.device),
+                    ),
+                )
+            else:
+                self.quantize(child)
+
+        return module
+
+    def quantized_model(self) -> nn.Module:
+        return self.quantize(self.model_)
+
+
 #########################################################################
 ##### GPTQ #####
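
Note: the re-added eager-mode `Int8DynActInt4WeightQuantizer` swaps every bias-free `nn.Linear` whose `out_features` is divisible by 8 for a `LinearAct8Int4DQ`, padding `in_features` up to the group size before quantizing. A minimal application sketch; the toy model and shapes are assumptions (real use targets a torchchat Transformer), and it requires torchao for `group_quantize_tensor_symmetric`.

```
# Sketch: applying the new quantizer to a toy bias-free MLP.
import torch
import torch.nn as nn
from quantize import Int8DynActInt4WeightQuantizer

model = nn.Sequential(
    nn.Linear(72, 64, bias=False),  # in_features=72 is padded to 96 for groupsize=32
    nn.ReLU(),
    nn.Linear(64, 32, bias=False),  # out_features must be divisible by 8
)

quantizer = Int8DynActInt4WeightQuantizer(model, device="cpu", groupsize=32)
qmodel = quantizer.quantized_model()     # replaces each Linear with LinearAct8Int4DQ in place

print(type(qmodel[0]).__name__)          # LinearAct8Int4DQ
print(qmodel(torch.randn(2, 72)).shape)  # torch.Size([2, 32]); values are placeholders here
```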

0 commit comments
