
Commit 68ee0b0

kimishpatel authored and malfet committed
Enable 8a4wdq (#264)
Summary:
- Removed Int8DynActInt4Weight code
- Use torchao to achieve the same

Test Plan:
python export.py --quant '{"linear:a8w4dq" : {"groupsize": 128}}' --checkpoint-path stories110M.pt --params-path params.json --output-pte-path /tmp/stories110m_a8w4dq.pte

Run
./build/cmake-out/runner_et /tmp/stories110m_a8w4dq.pte -z /tmp/tokenizer.bin -n 200 -t 0

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent f71a8b9 commit 68ee0b0
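
Illustration only, not part of the commit: the --quant value in the Test Plan is a JSON string, and the sketch below shows how such a string decomposes into a quantizer name and its options, matching the quantizer / q_kwargs naming used in quantize.py further down. Whether export.py parses it exactly this way is an assumption.

import json

# The value passed to export.py --quant in the Test Plan above.
quant_arg = '{"linear:a8w4dq" : {"groupsize": 128}}'
quantize_options = json.loads(quant_arg)

for quantizer, q_kwargs in quantize_options.items():
    # quantizer == "linear:a8w4dq" and q_kwargs == {"groupsize": 128};
    # the groupsize value is what ends up forwarded to torchao's quantizer.
    print(quantizer, q_kwargs)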

File tree

2 files changed, +18 -339 lines changed


.github/workflows/pull.yml

Lines changed: 3 additions & 3 deletions
@@ -285,10 +285,10 @@ jobs:
         cat ./output_et

         echo "******************************************"
-        echo "******** INT4 group-wise quantized *******"
+        echo "******** ET: a8w4dq INT4 group-wise quantized *******"
         echo "******************************************"
-        # python export.py --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --output-pte-path ${MODEL_DIR}/${MODEL_NAME}.pte
-        # python generate.py --checkpoint-path ${MODEL_PATH} --temperature 0 --pte-path ${MODEL_DIR}/${MODEL_NAME}.pte > ./output_et
+        python export.py --quant '{"linear:a8w4dq" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --output-pte-path ${MODEL_DIR}/${MODEL_NAME}.pte
+        python generate.py --checkpoint-path ${MODEL_PATH} --temperature 0 --pte-path ${MODEL_DIR}/${MODEL_NAME}.pte > ./output_et
         # cat ./output_et

         echo "tests complete"

quantize.py

Lines changed: 15 additions & 336 deletions
@@ -92,9 +92,20 @@ def quantize_model(model: nn.Module, device, quantize_options):
         ).quantized_model()
     elif quantizer == "linear:a8w4dq":
         linears_quantized = True
-        model = Int8DynActInt4WeightQuantHandler(
-            model, device, **q_kwargs
-        ).quantized_model()
+        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+        # Note that Int8DynActInt4WeightQuantizer takes precision as an
+        # arg, which is used to determine the precision/dtype of the output.
+        # That is, if dtype=fp32 then this dynamically quantized linear will
+        # return an output tensor with fp32 dtype.
+        # Ideally we make this dynamic such that the output dtype is determined
+        # based on the input dtype, instead of having to instantiate a quantizer
+        # that picks the output dtype.
+        # Since this requires a change in torchao, we leave the current state as is
+        # and use the default precision for Int8DynActInt4WeightQuantizer,
+        # which is fp32.
+        assert 'groupsize' in q_kwargs, f"a8w4dq quantization option must specify groupsize. Specified options {q_kwargs}"
+        model = Int8DynActInt4WeightQuantizer(
+            groupsize=q_kwargs['groupsize']).quantize(model)
     elif quantizer == "linear:gptq":
         linears_quantized = True
         model = WeightOnlyInt4GPTQQuantHandler(
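
A minimal, illustrative sketch of the torchao path the hunk above switches to. ToyModel and the 128/256 layer sizes are invented for this example; only the Int8DynActInt4WeightQuantizer import, the groupsize argument, and the .quantize(model) call are taken from the diff.

import torch
import torch.nn as nn
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # in_features should be divisible by groupsize for groupwise quantization.
        self.linear = nn.Linear(128, 256, bias=False)

    def forward(self, x):
        return self.linear(x)

model = ToyModel()
# The quantizer's default precision (fp32) determines the output dtype,
# as noted in the comment block added above.
model = Int8DynActInt4WeightQuantizer(groupsize=128).quantize(model)
out = model(torch.randn(1, 128))

The groupsize here plays the same role as the "groupsize" entry in the --quant JSON from the Test Plan.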
@@ -968,273 +979,12 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:


 #########################################################################
-##### Int8 Dynamic Activations 4 Bit Weights #####
-
-
-def prepare_int4_weight_and_scales_and_zeros(weight, groupsize, precision):
-    weight_int8, scales, zeros = group_quantize_tensor_symmetric(
-        weight,
-        n_bit=4,
-        groupsize=groupsize,
-        precision=precision,
-    )
-    # TODO: better API
-    # weight_int4packed = torch.ops.quantized_decomposed.pack_int4_from_int8(weight_int8)
-    return weight_int8, scales, zeros
-
-
-def linear_forward_8da4w(
-    x, weight_int8, scales, zeros, out_features, groupsize, precision
-):
-    x = per_token_dynamic_quant(x)
-    # TODO: verify and remove following reshape code
-    # origin_x_size = x.size()
-    # x = x.reshape(-1, origin_x_size[-1])
-
-    # TODO: better API
-    # weight_int8 = torch.ops.quantized_decomposed.unpack_int4_to_int8(weight_int4packed)
-    n_bit = 4
-    quant_min = -(2 ** (n_bit - 1))
-    quant_max = 2 ** (n_bit - 1) - 1
-    w_dq = torch.ops.quantized_decomposed.dequantize_per_channel_group(
-        weight_int8,
-        scales,
-        zeros,
-        quant_min,
-        quant_max,
-        torch.int8,
-        groupsize,
-        precision,
-    )
-
-    # x = x.to(torch.float16)
-    # w_dq = w_dq.to(torch.float16)
-    c = torch.nn.functional.linear(x, w_dq)
-
-    # new_shape = origin_x_size[:-1] + (out_features,)
-    # c = c.reshape(new_shape)
-
-    return c
-
-
-def find_multiple(n: int, *args: Tuple[int]) -> int:
-    k: int = reduce(lambda x, y: x * y // gcd(x, y), args + (1,))  # type: ignore[9]
-    if n % k == 0:
-        return n
-    return n + k - (n % k)
-
+##### GPTQ #####

 def _check_linear_int4_k(k, groupsize=1):
     return k % groupsize == 0

-def _calc_padded_size_linear_int4(k, groupsize=1):
-    return find_multiple(k, groupsize)
-
-
-def replace_linear_8da4w(
-    module,
-    groupsize,
-    padding_allowed,
-    precision,
-    scales_precision,
-):
-    for name, child in module.named_children():
-        if isinstance(child, nn.Linear):
-            if _check_linear_int4_k(child.in_features, groupsize) or padding_allowed:
-                setattr(
-                    module,
-                    name,
-                    Int8DynActInt4WeightLinear(
-                        child.in_features,
-                        child.out_features,
-                        bias=False,
-                        groupsize=groupsize,
-                        precision=precision,
-                        scales_precision=scales_precision,
-                    ),
-                )
-        else:
-            replace_linear_8da4w(
-                child,
-                groupsize,
-                padding_allowed,
-                precision,
-                scales_precision,
-            )
-
-
-class Int8DynActInt4WeightQuantHandler(QuantHandler):
-    def __init__(
-        self,
-        mod,
-        device,
-        *,
-        groupsize=256,
-        padding_allowed=False,
-        precision=torch.float32,
-        scales_precision=torch.float32,
-    ):
-        self.mod = mod
-        self.device = device
-        self.groupsize = groupsize
-        self.padding_allowed = padding_allowed
-        self.precision = precision
-        self.scales_precision = scales_precision
-        # assert groupsize in [32, 64, 128, 256]
-
-    @torch.no_grad()
-    def create_quantized_state_dict(self):
-        cur_state_dict = self.mod.state_dict()
-        for fqn, mod in self.mod.named_modules():
-            if isinstance(mod, torch.nn.Linear):
-                assert not mod.bias
-                in_features = mod.in_features
-                # print("in features:", in_features, " out features:", out_features)
-                # assert out_features % 8 == 0, "require out_features % 8 == 0"
-                # print(f"linear: {fqn}, in={in_features}, out={out_features}")
-
-                assert (
-                    in_features % self.groupsize == 0
-                ), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0"
-
-                weight = mod.weight.data
-                """
-                if not _check_linear_int4_k(
-                    in_features, self.groupsize
-                ):
-                    if self.padding_allowed:
-                        print(
-                            f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
-                        )
-                        padded_in_features = _calc_padded_size_linear_int4(
-                            in_features, self.groupsize
-                        )
-                        weight = F.pad(
-                            weight, pad=(0, padded_in_features - in_features)
-                        )
-                    else:
-                        raise RuntimeError(
-                            f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
-                            + "and that groupsize"
-                        )
-                """
-                (
-                    weight_int4pack,
-                    scales,
-                    zeros,
-                ) = prepare_int4_weight_and_scales_and_zeros(
-                    weight.to(self.precision),
-                    self.groupsize,
-                    self.scales_precision,
-                )
-                cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
-                cur_state_dict[f"{fqn}.scales"] = scales.to("cpu")
-                cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu")
-
-        return cur_state_dict
-
-    def convert_for_runtime(self):
-        replace_linear_8da4w(
-            self.mod,
-            self.groupsize,
-            self.padding_allowed,
-            self.precision,
-            self.scales_precision,
-        )
-        return self.mod
-
-    def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict()
-        self.convert_for_runtime()
-        self.mod.load_state_dict(model_updated_state_dict)
-        return self.mod
-
-
-class Int8DynActInt4WeightLinear(torch.nn.Module):
-    __constants__ = ["in_features", "out_features"]
-
-    in_features: int
-    out_features: int
-    weight: torch.Tensor
-
-    """
-    This module implements a dynamic quantized linear layer with int4 weight.
-    Weights are per channel groupwise quantized. Parameters of importance
-    groupsize: the number of elements in each quantized group
-    precision: precision of input and output. e.g. torch.float32 means input
-    activation is float32 and output is float32.
-    scales_precision: precision of per group scale.
-    """
-
-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-        bias=True,
-        device=None,
-        dtype=None,
-        groupsize: int = 256,
-        precision: torch.dtype = torch.float32,
-        scales_precision: torch.dtype = torch.float32,
-    ) -> None:
-        super().__init__()
-        # always pad if needed since it becomes a noop at runtime if not needed
-        # self.origin_in_features = in_features
-        assert (
-            in_features % groupsize == 0
-        ), f"require in_features:{in_features} % groupsize:{groupsize} == 0"
-        # in_features = _calc_padded_size_linear_int4(
-        #     in_features, groupsize
-        # )
-        self.in_features = in_features
-        self.out_features = out_features
-        assert not bias, "require bias=False"
-        self.groupsize = groupsize
-        # Precision of the activation which also indicates
-        # output precision of the dynamically quantized linear layer
-        # that his module represents.
-        self.precision = precision
-
-        # currently storing unpacked int8 weights
-        self.register_buffer(
-            "weight",
-            torch.empty((out_features, in_features), dtype=torch.int8),
-        )
-        self.register_buffer(
-            "scales",
-            torch.empty(
-                (out_features, in_features // groupsize),
-                dtype=scales_precision,
-            ),
-        )
-        self.register_buffer(
-            "zeros",
-            torch.empty(
-                (out_features, in_features // groupsize),
-                dtype=scales_precision,
-            ),
-        )
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        input = input.to(self.precision)
-        # padding is removed for perf
-        # input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
-        return linear_forward_8da4w(
-            input,
-            self.weight,
-            self.scales,
-            self.zeros,
-            self.out_features,
-            self.groupsize,
-            self.precision,
-        )
-
-
-#########################################################################
-##### GPTQ #####
-
-
 class GPTQQuantHandler(QuantHandler):
     """
     This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class.
@@ -1445,77 +1195,6 @@ def quantized_model(self) -> nn.Module:
         return self.mod


-# class Int8DynActInt4WeightGPTQQuantHandler(GPTQQuantHandler):
-#     def __init__(
-#         self,
-#         groupsize=128,
-#         inner_k_tiles=8,
-#         padding_allowed=True,
-#         precision=torch.float32,
-#     ):
-
-#         self.groupsize = groupsize
-#         self.inner_k_tiles = inner_k_tiles
-#         self.padding_allowed = padding_allowed
-#         self.precision = precision
-#         self.dyn_quant_func = lambda x: per_token_dynamic_quant(x)
-#         n_bit = 4
-#         self.get_qparams_func = lambda w: get_group_qparams_symmetric(
-#             w, n_bit, groupsize, self.precision
-#         )
-#         quant_min = -(2 ** (n_bit - 1))
-#         quant_max = 2 ** (n_bit - 1) - 1
-#         self.quantize_func = lambda w, qparams: torch.ops.quantized_decomposed.quantize_per_channel_group(
-#             w, qparams[0], qparams[1], quant_min, quant_max, torch.int8, groupsize
-#         )
-#         self.dequantize_func = lambda q, qparams: torch.ops.quantized_decomposed.dequantize_per_channel_group(
-#             q,
-#             qparams[0],
-#             qparams[1],
-#             quant_min,
-#             quant_max,
-#             torch.int8,
-#             groupsize,
-#             self.precision,
-#         )
-#         self.combine_qparams_list_func = lambda qparams_list: [
-#             torch.cat(x, dim=1) for x in zip(*qparams_list)
-#         ]
-#         # skip unless padding_allowed=True or its correctly sized
-#         self.skip_layer_func = lambda linear_weight: not (
-#             _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles)
-#             or padding_allowed
-#         )
-
-#         # we need to do the padding here, both for q and the qparams if necessary
-#         def make_names_and_values_dict_func(q, qparams):
-#             k = q.shape[1]
-#             new_k = _calc_padded_size_linear_int4(k, groupsize, inner_k_tiles)
-#             # how much we need to pad the weight
-#             delta_k = new_k - q.shape[1]
-#             final_q = F.pad(q, pad=(0, delta_k))
-#             scales_and_zeros = pack_scales_and_zeros(*qparams, precision=self.precision)
-#             # how many new groups we need for padded weight
-#             delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
-#             # TODO: split scales and zero_points
-#             final_s_and_z = F.pad(
-#                 scales_and_zeros, pad=(0, 0, 0, 0, 0, delta_groups), value=1
-#             )
-#             return {"weight": final_q, "scales_and_zeros": final_s_and_z}
-
-#         self.make_names_and_values_dict_func = make_names_and_values_dict_func
-#         super().__init__()
-
-#     def convert_for_runtime(self, model):
-#         replace_linear_8da4w(
-#             model,
-#             self.groupsize,
-#             self.padding_allowed,
-#             torch.int8,
-#             self.precision,
-#         )
-#         return model
-
 ##################################################################
 ### WIP: HQQ ###

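For reference, an illustrative sketch (plain torch, not the deleted helpers and not the torchao kernels) of the groupwise symmetric 4-bit dequantization that the removed linear_forward_8da4w performed before its F.linear call. The shapes and random data below are invented; only the n_bit/quant_min/quant_max arithmetic and the per-group scaling mirror the deleted code.

import torch

n_bit = 4
quant_min, quant_max = -(2 ** (n_bit - 1)), 2 ** (n_bit - 1) - 1  # -8 and 7

out_features, in_features, groupsize = 4, 8, 4
# 4-bit weight values stored in an int8 tensor, with one scale per (row, group).
weight_int8 = torch.randint(quant_min, quant_max + 1, (out_features, in_features), dtype=torch.int8)
scales = torch.rand(out_features, in_features // groupsize)

# Symmetric quantization: zero points are zero, so dequant is just scale * q per group.
w_dq = (
    weight_int8.to(torch.float32).reshape(out_features, -1, groupsize)
    * scales.unsqueeze(-1)
).reshape(out_features, in_features)

x = torch.randn(2, in_features)
c = torch.nn.functional.linear(x, w_dq)  # mirrors the final F.linear step in the removed code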