Update quantize.py to use AO's int4 quantizer #919

Merged: 13 commits, Jul 19, 2024
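
At a high level, this PR deletes the hand-rolled WeightOnlyInt4QuantHandler from quantization/quantize.py and routes `linear:int4` through torchao instead: on CUDA it uses the tensor-subclass API (`quantize_` plus `int4_weight_only`), elsewhere the `Int4WeightOnlyQuantizer` class. Below is a minimal sketch of the tensor-subclass flow using the same calls that appear in the diff; the toy module, device, and group size here are illustrative only, not taken from the PR.

# Sketch of the torchao tensor-subclass flow this PR adopts for linear:int4 on CUDA.
# quantize_/int4_weight_only/unwrap_tensor_subclass mirror the imports in the diff;
# the toy model, group size, and device are illustrative.
import torch
import torch.nn as nn
from torchao.quantization.quant_api import int4_weight_only, quantize_
from torchao.utils import unwrap_tensor_subclass

model = nn.Sequential(nn.Linear(4096, 4096, bias=False)).to(device="cuda", dtype=torch.bfloat16)

# Swap the weights of eligible nn.Linear modules for int4 weight-only tensor subclasses.
quantize_(model, int4_weight_only(32))  # groupsize=32, matching the CI quant config

# Flatten the tensor subclasses back to plain tensors so export/compile see an
# ordinary module (quantize_model does the same below).
unwrap_tensor_subclass(model)

with torch.no_grad():
    y = model(torch.randn(1, 4096, device="cuda", dtype=torch.bfloat16))
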
97 changes: 50 additions & 47 deletions .ci/scripts/validate.sh
@@ -92,13 +92,16 @@ function generate_compiled_model_output() {
python3 -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1
.ci/scripts/check_gibberish "$MODEL_DIR/output_compiled"

echo "******************************************"
echo "******** INT4 group-wise quantized *******"
echo "******************************************"
python3 -W ignore generate.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_eager" || exit 1
.ci/scripts/check_gibberish "$MODEL_DIR/output_eager"
python3 -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1
.ci/scripts/check_gibberish "$MODEL_DIR/output_compiled"
if [[ $TARGET_DEVICE != "cuda" || "$DTYPE" == "bfloat16" ]]; then
# For CUDA, only bfloat16 makes sense for int4 mm kernel
echo "******************************************"
echo "******** INT4 group-wise quantized *******"
echo "******************************************"
python3 -W ignore generate.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_eager" || exit 1
.ci/scripts/check_gibberish "$MODEL_DIR/output_eager"
python3 -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1
.ci/scripts/check_gibberish "$MODEL_DIR/output_compiled"
fi
fi
done
}
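
The guard added above (and repeated in the hunks below) encodes a single rule: on CUDA, the int4 mm kernel is only exercised with bfloat16. For readers skimming the shell, a hypothetical Python mirror of that condition:

# Hypothetical helper mirroring the shell guard in validate.sh:
# run the INT4 checks everywhere except CUDA with a non-bfloat16 dtype.
def should_run_int4(target_device: str, dtype: str) -> bool:
    return target_device != "cuda" or dtype == "bfloat16"

assert should_run_int4("cpu", "float16")
assert should_run_int4("cuda", "bfloat16")
assert not should_run_int4("cuda", "float16")
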
@@ -180,12 +183,11 @@ function generate_aoti_model_output() {
echo "******************************************"
echo "******** INT4 group-wise quantized *******"
echo "******************************************"
if [ "$TARGET_DEVICE" == "cuda" ]; then
if [ "$DTYPE" != "float16" ]; then
python3 -W ignore export.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
python3 -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
.ci/scripts/check_gibberish "$MODEL_DIR/output_aoti"
fi
if [[ "$TARGET_DEVICE" != "cuda" || "$DTYPE" == "bfloat16" ]]; then
# For CUDA, only bfloat16 makes sense for int4 mm kernel
python3 -W ignore export.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
python3 -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
.ci/scripts/check_gibberish "$MODEL_DIR/output_aoti"
fi
done
}
@@ -225,21 +227,23 @@ function eval_model() {
echo "perplexity checking succeeded for non-quantized model $MODEL_NAME with $DTYPE $TARGET_DEVICE"
fi;

echo "******************************************"
echo "******** INT4 group-wise quantized *******"
echo "******************************************"
if [[ "$TARGET_DEVICE" != "cuda" || "$DTYPE" == "bfloat16" ]]; then
echo "******************************************"
echo "******** INT4 group-wise quantized *******"
echo "******************************************"

export QUANT_OPTIONS='{"linear:int4" : {"groupsize": 32}}'
python -W ignore eval.py --compile --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --device "$TARGET_DEVICE" > "$MODEL_DIR/eval" || exit 1
cat "$MODEL_DIR/eval"
export REF_PERPLEXITY=100000
export PERPLEXITY=$(cat "$MODEL_DIR/eval" | tail -n 1 | awk -F '[, ]' '{print $4}')
# == 1 means the check succeeded
if [ "$(echo "$PERPLEXITY >= $REF_PERPLEXITY" | bc)" == 1 ]; then
echo "perplexity checking failed for int4-quantized model $MODEL_NAME with $DTYPE $TARGET_DEVICE $QUANT_OPTIONS"
else
echo "perplexity checking succeeded for int4-quantized model $MODEL_NAME with $DTYPE $TARGET_DEVICE $QUANT_OPTIONS"
fi;
export QUANT_OPTIONS='{"linear:int4" : {"groupsize": 32}}'
python -W ignore eval.py --compile --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --device "$TARGET_DEVICE" > "$MODEL_DIR/eval" || exit 1
cat "$MODEL_DIR/eval"
export REF_PERPLEXITY=100000
export PERPLEXITY=$(cat "$MODEL_DIR/eval" | tail -n 1 | awk -F '[, ]' '{print $4}')
# == 1 means the check succeeded
if [ "$(echo "$PERPLEXITY >= $REF_PERPLEXITY" | bc)" == 1 ]; then
echo "perplexity checking failed for int4-quantized model $MODEL_NAME with $DTYPE $TARGET_DEVICE $QUANT_OPTIONS"
else
echo "perplexity checking succeeded for int4-quantized model $MODEL_NAME with $DTYPE $TARGET_DEVICE $QUANT_OPTIONS"
fi;
fi

done
}
@@ -260,32 +264,31 @@ function eval_model_sanity_check() {
python -W ignore eval.py --compile --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/eval" || exit 1
cat "$MODEL_DIR/eval"

echo "******************************************"
echo "******** INT4 group-wise quantized *******"
echo "******************************************"
if [[ "$TARGET_DEVICE" != "cuda" || "$DTYPE" == "bfloat16" ]]; then
echo "******************************************"
echo "******** INT4 group-wise quantized *******"
echo "******************************************"

export QUANT_OPTIONS='{"linear:int4" : {"groupsize": 32}}'
python -W ignore eval.py --compile --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/eval" || exit 1
cat "$MODEL_DIR/eval"
export QUANT_OPTIONS='{"linear:int4" : {"groupsize": 32}}'
python -W ignore eval.py --compile --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/eval" || exit 1
cat "$MODEL_DIR/eval"

echo "**************************************************"
echo "******** INT4 group-wise quantized (eager) *******"
echo "**************************************************"
echo "**************************************************"
echo "******** INT4 group-wise quantized (eager) *******"
echo "**************************************************"

if [ "$TARGET_DEVICE" == "cuda" ] && [ "$DTYPE" != "float16" ]; then
python -W ignore eval.py --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/eval_eager" || exit 1
cat "$MODEL_DIR/eval_eager"
fi;


# there are some issues with AOTI on cpu and cuda; need to fix and enable the test for cuda as well
echo "*************************************************"
echo "******** INT4 group-wise quantized (AOTI) *******"
echo "*************************************************"
if [ "$DTYPE" != "float16" ]; then
python3 -W ignore export.py --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
python3 -W ignore eval.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/output_eval_aoti" || exit 1
cat "$MODEL_DIR/output_eval_aoti"
# there are some issues with AOTI on cpu and cuda; need to fix and enable the test for cuda as well
echo "*************************************************"
echo "******** INT4 group-wise quantized (AOTI) *******"
echo "*************************************************"
if [ "$DTYPE" != "float16" ]; then
python3 -W ignore export.py --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
python3 -W ignore eval.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/output_eval_aoti" || exit 1
cat "$MODEL_DIR/output_eval_aoti"
fi;
fi;

done
2 changes: 1 addition & 1 deletion install_requirements.sh
@@ -73,7 +73,7 @@ $PIP_EXECUTABLE install --extra-index-url "${TORCH_NIGHTLY_URL}" \

# For torchao, we need to install from GitHub since the nightly build doesn't include a macOS build.
# TODO: Remove this and install the nightly build once it supports macOS
$PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@ca1b98db60543a1669a32e842762fc008c178376
$PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@d36de1b144b73bf753bd082109c2b5d0141abd5b
if [[ -x "$(command -v nvidia-smi)" ]]; then
$PYTHON_EXECUTABLE scripts/patch_triton.py
fi
114 changes: 15 additions & 99 deletions quantization/quantize.py
@@ -15,23 +15,19 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from build.utils import (
find_multiple,
get_device_str,
get_precision,
name_to_dtype,
state_dict_device,
)
from build.utils import get_device_str, get_precision, name_to_dtype, state_dict_device

from quantization.qops import (
LinearInt4 as WeightOnlyInt4Linear,
LinearInt8 as WeightOnlyInt8Linear,
QuantizedEmbedding,
)
from quantization.qops import LinearInt8 as WeightOnlyInt8Linear, QuantizedEmbedding

# AttributeError: '_OpNamespace' 'quantized_decomposed' object has no attribute 'quantize_per_channel_group'
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
from torchao.quantization.quant_api import (
int4_weight_only,
Int4WeightOnlyQuantizer,
Int8DynActInt4WeightQuantizer,
quantize_,
)
from torchao.utils import unwrap_tensor_subclass


#########################################################################
@@ -59,6 +55,11 @@ def quantize_model(model: nn.Module, device, quantize_options, tokenizer=None):
):
raise RuntimeError(f"unknown quantizer {quantizer} specified")
if quantizer in ao_quantizer_class_dict:
# Use tensor subclass API for int4 weight only.
if device == "cuda" and quantizer == "linear:int4":
quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
unwrap_tensor_subclass(model)
continue
# Use the dtype precision specified in the user config, else fall back to the global precision.
if "precision" in quantize_options:
dtype = quantize_options["precision"].get("dtype", str(get_precision()))
@@ -540,91 +541,6 @@ def quantized_model(self) -> nn.Module:
return self.quantize(self.model_)


#########################################################################
##### weight only int4 per channel groupwise quantized code ######


class WeightOnlyInt4QuantHandler(QuantHandler):
def __init__(
self,
model: nn.Module,
device=None,
*,
tokenizer=None,
groupsize=128,
inner_k_tiles=8,
padding_allowed=True,
):
self.model_ = model
self.device = device
self.groupsize = groupsize
self.inner_k_tiles = inner_k_tiles
self.padding_allowed = padding_allowed
assert groupsize in [32, 64, 128, 256]
assert inner_k_tiles in [2, 4, 8]

@torch.no_grad()
def quantize(self, module):
for name, child in module.named_children():
# print(f"name: {name}")
if isinstance(child, torch.nn.Linear):
assert not child.bias
out_features = child.out_features
in_features = child.in_features
assert out_features % 8 == 0, "require out_features % 8 == 0"
# print(f"linear: {fqn}, in={in_features}, out={out_features}")

weight = child.weight.data
if not WeightOnlyInt4Linear._check_k(
k=in_features,
groupsize=self.groupsize,
inner_k_tiles=self.inner_k_tiles,
):
if self.padding_allowed:
# print(
# f"warning: {name} is padded to satisfy in_features % 1024 == 0"
# )
padded_in_features = find_multiple(in_features, 1024)
weight = F.pad(
weight, pad=(0, padded_in_features - in_features)
)
else:
print(
f"warning: {name} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
+ "and that groupsize and inner_k_tiles*16 evenly divide into it"
)
continue
weight_int4pack, scales_and_zeros = (
WeightOnlyInt4Linear._prepare_weight_and_scales_and_zeros(
weight.to(torch.float), self.groupsize, self.inner_k_tiles
)
)
weight_int4pack = weight_int4pack.to(device=self.device)
scales_and_zeros = scales_and_zeros.to(device=self.device)

setattr(
module,
name,
WeightOnlyInt4Linear(
child.in_features,
child.out_features,
bias=False,
device=self.device,
groupsize=self.groupsize,
inner_k_tiles=self.inner_k_tiles,
weight=weight_int4pack,
scales_and_zeros=scales_and_zeros,
),
)
else:
self.quantize(child)

return module

def quantized_model(self) -> nn.Module:
return self.quantize(self.model_)


##########################################################################
### quantization dictionary ###

@@ -634,11 +550,11 @@
quantizer_class_dict = {
"embedding": EmbeddingOnlyQuantHandler,
"linear:int8": WeightOnlyInt8QuantHandler,
"linear:int4": WeightOnlyInt4QuantHandler,
"precision": PrecisionHandler,
"executor": ExecutorHandler,
}

ao_quantizer_class_dict = {
"linear:int4": Int4WeightOnlyQuantizer,
"linear:a8w4dq": Int8DynActInt4WeightQuantizer,
}
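
For orientation: entries in quantizer_class_dict are the repo's eager QuantHandler classes, while entries in ao_quantizer_class_dict are torchao quantizer classes constructed and applied inside quantize_model. A hedged sketch of the torchao path follows; the constructor kwargs are abbreviated and the precision/device wiring done by the real quantize_model is omitted.

# Hedged sketch of applying an ao_quantizer_class_dict entry (linear:a8w4dq shown).
# Constructor kwargs are abbreviated; precision/device wiring is omitted.
import torch.nn as nn
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

def apply_a8w4dq(model: nn.Module, groupsize: int = 256) -> nn.Module:
    quant_handler = Int8DynActInt4WeightQuantizer(groupsize=groupsize)
    # torchao quantizers rewrite the eligible nn.Linear modules and return the model.
    return quant_handler.quantize(model)
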