
Commit 87798fd

Update quantize.py to use AO's int4 quantizer (#919)
* Use ao's int4 quantizer
* Point AO to commit hash of Jerry's fix
* When device is cuda, only run for dtype==bfloat16
* Typo
* Use tensor subclass for int4 weight only quant
* Fix bug
* Fix
* Use both quantizer and subclass API
* Bug
* unwrap tensor subclass for aoti
* Add import
* Eval fix
* Evaluate AOTI

---------

Co-authored-by: Mengwei Liu <[email protected]>
1 parent 53344db commit 87798fd

File tree

2 files changed (+65, -146 lines)

.ci/scripts/validate.sh

Lines changed: 50 additions & 47 deletions
@@ -92,13 +92,16 @@ function generate_compiled_model_output() {
         python3 -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1
         .ci/scripts/check_gibberish "$MODEL_DIR/output_compiled"
 
-        echo "******************************************"
-        echo "******** INT4 group-wise quantized *******"
-        echo "******************************************"
-        python3 -W ignore generate.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_eager" || exit 1
-        .ci/scripts/check_gibberish "$MODEL_DIR/output_eager"
-        python3 -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1
-        .ci/scripts/check_gibberish "$MODEL_DIR/output_compiled"
+        if [[ $TARGET_DEVICE != "cuda" || "$DTYPE" == "bfloat16" ]]; then
+            # For CUDA, only bfloat16 makes sense for int4 mm kernel
+            echo "******************************************"
+            echo "******** INT4 group-wise quantized *******"
+            echo "******************************************"
+            python3 -W ignore generate.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_eager" || exit 1
+            .ci/scripts/check_gibberish "$MODEL_DIR/output_eager"
+            python3 -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1
+            .ci/scripts/check_gibberish "$MODEL_DIR/output_compiled"
+        fi
     fi
   done
 }
@@ -180,12 +183,11 @@ function generate_aoti_model_output() {
     echo "******************************************"
     echo "******** INT4 group-wise quantized *******"
     echo "******************************************"
-    if [ "$TARGET_DEVICE" == "cuda" ]; then
-      if [ "$DTYPE" != "float16" ]; then
-        python3 -W ignore export.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
-        python3 -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
-        .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti"
-      fi
+    if [[ "$TARGET_DEVICE" != "cuda" || "$DTYPE" == "bfloat16" ]]; then
+      # For CUDA, only bfloat16 makes sense for int4 mm kernel
+      python3 -W ignore export.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
+      python3 -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
+      .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti"
     fi
   done
 }
@@ -225,21 +227,23 @@ function eval_model() {
       echo "perplexity checking succeeded for non-quantized model $MODEL_NAME with $DTYPE $TARGET_DEVICE"
     fi;
 
-    echo "******************************************"
-    echo "******** INT4 group-wise quantized *******"
-    echo "******************************************"
+    if [[ "$TARGET_DEVICE" != "cuda" || "$DTYPE" == "bfloat16" ]]; then
+      echo "******************************************"
+      echo "******** INT4 group-wise quantized *******"
+      echo "******************************************"
 
-    export QUANT_OPTIONS='{"linear:int4" : {"groupsize": 32}}'
-    python -W ignore eval.py --compile --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --device "$TARGET_DEVICE" > "$MODEL_DIR/eval" || exit 1
-    cat "$MODEL_DIR/eval"
-    export REF_PERPLEXITY=100000
-    export PERPLEXITY=cat "$MODEL_DIR/eval" | tail -n 1 log | awk -F '[, ]' '{print $4}'
-    # == 1 meaning the check succeeded
-    if [ "$(echo "$PERPLEXITY >= $REF_PERPLEXITY" | bc)" == 1]; then
-      echo "perplexity checking failed for int4-quantized model $MODEL_NAME with $DTYPE $TARGET_DEVICE $QUANT_OPTIONS"
-    else
-      echo "perplexity checking succeeded for int4-quantized model $MODEL_NAME with $DTYPE $TARGET_DEVICE $QUANT_OPTIONS"
-    fi;
+      export QUANT_OPTIONS='{"linear:int4" : {"groupsize": 32}}'
+      python -W ignore eval.py --compile --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --device "$TARGET_DEVICE" > "$MODEL_DIR/eval" || exit 1
+      cat "$MODEL_DIR/eval"
+      export REF_PERPLEXITY=100000
+      export PERPLEXITY=cat "$MODEL_DIR/eval" | tail -n 1 log | awk -F '[, ]' '{print $4}'
+      # == 1 meaning the check succeeded
+      if [ "$(echo "$PERPLEXITY >= $REF_PERPLEXITY" | bc)" == 1]; then
+        echo "perplexity checking failed for int4-quantized model $MODEL_NAME with $DTYPE $TARGET_DEVICE $QUANT_OPTIONS"
+      else
+        echo "perplexity checking succeeded for int4-quantized model $MODEL_NAME with $DTYPE $TARGET_DEVICE $QUANT_OPTIONS"
+      fi;
+    fi
 
   done
 }
@@ -260,32 +264,31 @@ function eval_model_sanity_check() {
     python -W ignore eval.py --compile --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/eval" || exit 1
     cat "$MODEL_DIR/eval"
 
-    echo "******************************************"
-    echo "******** INT4 group-wise quantized *******"
-    echo "******************************************"
+    if [[ "$TARGET_DEVICE" != "cuda" || "$DTYPE" == "bfloat16" ]]; then
+      echo "******************************************"
+      echo "******** INT4 group-wise quantized *******"
+      echo "******************************************"
 
-    export QUANT_OPTIONS='{"linear:int4" : {"groupsize": 32}}'
-    python -W ignore eval.py --compile --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/eval" || exit 1
-    cat "$MODEL_DIR/eval"
+      export QUANT_OPTIONS='{"linear:int4" : {"groupsize": 32}}'
+      python -W ignore eval.py --compile --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/eval" || exit 1
+      cat "$MODEL_DIR/eval"
 
-    echo "**************************************************"
-    echo "******** INT4 group-wise quantized (eager) *******"
-    echo "**************************************************"
+      echo "**************************************************"
+      echo "******** INT4 group-wise quantized (eager) *******"
+      echo "**************************************************"
 
-    if [ "$TARGET_DEVICE" == "cuda" ] && [ "$DTYPE" != "float16" ]; then
       python -W ignore eval.py --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/eval_eager" || exit 1
       cat "$MODEL_DIR/eval_eager"
-    fi;
 
-
-    # there is some issues with AOTI cpu and cuda, need to fix and enable the test for cuda as well
-    echo "*************************************************"
-    echo "******** INT4 group-wise quantized (AOTI) *******"
-    echo "*************************************************"
-    if [ "$DTYPE" != "float16" ]; then
-      python3 -W ignore export.py --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
-      python3 -W ignore eval.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/output_eval_aoti" || exit 1
-      cat "$MODEL_DIR/output_eval_aoti"
+      # there is some issues with AOTI cpu and cuda, need to fix and enable the test for cuda as well
+      echo "*************************************************"
+      echo "******** INT4 group-wise quantized (AOTI) *******"
+      echo "*************************************************"
+      if [ "$DTYPE" != "float16" ]; then
+        python3 -W ignore export.py --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
+        python3 -W ignore eval.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/output_eval_aoti" || exit 1
+        cat "$MODEL_DIR/output_eval_aoti"
+      fi;
     fi;
 
   done
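
The guard added throughout this script is the same in every function: skip the int4 paths on CUDA unless the dtype is bfloat16, since, per the in-script comment, only bfloat16 makes sense for the int4 mm kernel on CUDA. Purely as an illustration of the predicate (the function name and torch.dtype signature below are assumptions for this sketch, not part of the commit):

    import torch

    def should_run_int4(target_device: str, dtype: torch.dtype) -> bool:
        # Mirrors the bash guard `[[ $TARGET_DEVICE != "cuda" || "$DTYPE" == "bfloat16" ]]`:
        # on CUDA the int4 tests run only for bfloat16; other devices run all dtypes.
        return target_device != "cuda" or dtype == torch.bfloat16

    assert should_run_int4("cpu", torch.float32)
    assert should_run_int4("cuda", torch.bfloat16)
    assert not should_run_int4("cuda", torch.float16)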

quantization/quantize.py

Lines changed: 15 additions & 99 deletions
@@ -31,23 +31,19 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from build.utils import (
-    find_multiple,
-    get_device_str,
-    get_precision,
-    name_to_dtype,
-    state_dict_device,
-)
+from build.utils import get_device_str, get_precision, name_to_dtype, state_dict_device
 
-from quantization.qops import (
-    LinearInt4 as WeightOnlyInt4Linear,
-    LinearInt8 as WeightOnlyInt8Linear,
-    QuantizedEmbedding,
-)
+from quantization.qops import LinearInt8 as WeightOnlyInt8Linear, QuantizedEmbedding
 
 # AttributeError: '_OpNamespace' 'quantized_decomposed' object has no attribute 'quantize_per_channel_group'
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa
-from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+from torchao.quantization.quant_api import (
+    int4_weight_only,
+    Int4WeightOnlyQuantizer,
+    Int8DynActInt4WeightQuantizer,
+    quantize_,
+)
+from torchao.utils import unwrap_tensor_subclass
 
 
 #########################################################################
@@ -75,6 +71,11 @@ def quantize_model(model: nn.Module, device, quantize_options, tokenizer=None):
         ):
             raise RuntimeError(f"unknown quantizer {quantizer} specified")
         if quantizer in ao_quantizer_class_dict:
+            # Use tensor subclass API for int4 weight only.
+            if device == "cuda" and quantizer == "linear:int4":
+                quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
+                unwrap_tensor_subclass(model)
+                continue
             # Use dtype precision specified in user config, else fallback on global precision.
             if "precision" in quantize_options:
                 dtype = quantize_options["precision"].get("dtype", str(get_precision()))
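
The new branch above routes CUDA int4 weight-only quantization through torchao's tensor-subclass API instead of the module-swap handler removed further down. A minimal sketch of that flow, assuming a CUDA build with torchao installed (the toy model and groupsize 32 are placeholders; the real call forwards q_kwargs["groupsize"] from the user's quant config):

    import torch
    import torch.nn as nn
    from torchao.quantization.quant_api import int4_weight_only, quantize_
    from torchao.utils import unwrap_tensor_subclass

    # Toy stand-in for the transformer that quantize_model receives.
    model = nn.Sequential(nn.Linear(1024, 1024, bias=False)).to(device="cuda", dtype=torch.bfloat16)

    # Replace Linear weights in place with int4 weight-only quantized tensor subclasses.
    quantize_(model, int4_weight_only(32))

    # Flatten the subclasses back into plain tensors so export can consume the model;
    # this is what the "unwrap tensor subclass for aoti" bullet in the commit message refers to.
    unwrap_tensor_subclass(model)

Eager and compiled generation run directly on the subclassed weights; the unwrap step is only needed before AOTI export.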
@@ -556,91 +557,6 @@ def quantized_model(self) -> nn.Module:
         return self.quantize(self.model_)
 
 
-#########################################################################
-#####     weight only int4 per channel groupwise quantized code    ######
-
-
-class WeightOnlyInt4QuantHandler(QuantHandler):
-    def __init__(
-        self,
-        model: nn.Module,
-        device=None,
-        *,
-        tokenizer=None,
-        groupsize=128,
-        inner_k_tiles=8,
-        padding_allowed=True,
-    ):
-        self.model_ = model
-        self.device = device
-        self.groupsize = groupsize
-        self.inner_k_tiles = inner_k_tiles
-        self.padding_allowed = padding_allowed
-        assert groupsize in [32, 64, 128, 256]
-        assert inner_k_tiles in [2, 4, 8]
-
-    @torch.no_grad()
-    def quantize(self, module):
-        for name, child in module.named_children():
-            # print(f"name: {name}")
-            if isinstance(child, torch.nn.Linear):
-                assert not child.bias
-                out_features = child.out_features
-                in_features = child.in_features
-                assert out_features % 8 == 0, "require out_features % 8 == 0"
-                # print(f"linear: {fqn}, in={in_features}, out={out_features}")
-
-                weight = child.weight.data
-                if not WeightOnlyInt4Linear._check_k(
-                    k=in_features,
-                    groupsize=self.groupsize,
-                    inner_k_tiles=self.inner_k_tiles,
-                ):
-                    if self.padding_allowed:
-                        # print(
-                        #     f"warning: {name} is padded to satisfy in_features % 1024 == 0"
-                        # )
-                        padded_in_features = find_multiple(in_features, 1024)
-                        weight = F.pad(
-                            weight, pad=(0, padded_in_features - in_features)
-                        )
-                    else:
-                        print(
-                            f"warning: {name} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
-                            + "and that groupsize and inner_k_tiles*16 evenly divide into it"
-                        )
-                        continue
-                weight_int4pack, scales_and_zeros = (
-                    WeightOnlyInt4Linear._prepare_weight_and_scales_and_zeros(
-                        weight.to(torch.float), self.groupsize, self.inner_k_tiles
-                    )
-                )
-                weight_int4pack = weight_int4pack.to(device=self.device)
-                scales_and_zeros = scales_and_zeros.to(device=self.device)
-
-                setattr(
-                    module,
-                    name,
-                    WeightOnlyInt4Linear(
-                        child.in_features,
-                        child.out_features,
-                        bias=False,
-                        device=self.device,
-                        groupsize=self.groupsize,
-                        inner_k_tiles=self.inner_k_tiles,
-                        weight=weight_int4pack,
-                        scales_and_zeros=scales_and_zeros,
-                    ),
-                )
-            else:
-                self.quantize(child)
-
-        return module
-
-    def quantized_model(self) -> nn.Module:
-        return self.quantize(self.model_)
-
-
 ##########################################################################
 ###                     quantization dictionary                       ###
 
@@ -650,11 +566,11 @@ def quantized_model(self) -> nn.Module:
 quantizer_class_dict = {
     "embedding": EmbeddingOnlyQuantHandler,
     "linear:int8": WeightOnlyInt8QuantHandler,
-    "linear:int4": WeightOnlyInt4QuantHandler,
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
 }
 
 ao_quantizer_class_dict = {
+    "linear:int4": Int4WeightOnlyQuantizer,
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
 }