@@ -92,9 +92,20 @@ def quantize_model(model: nn.Module, device, quantize_options):
             ).quantized_model()
         elif quantizer == "linear:a8w4dq":
             linears_quantized = True
-            model = Int8DynActInt4WeightQuantHandler(
-                model, device, **q_kwargs
-            ).quantized_model()
+            from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+            # Note that Int8DynActInt4WeightQuantizer takes precision as an
+            # arg, which is used to determine the precision/dtype of the output.
+            # That is, if dtype=fp32 then this dynamically quantized linear will
+            # return an output tensor with fp32 dtype.
+            # Ideally we would make this dynamic so that the output dtype is
+            # determined by the input dtype, instead of having to instantiate a
+            # quantizer that picks the output dtype.
+            # Since this requires a change in torchao, we leave the current state
+            # as is and use the default precision for Int8DynActInt4WeightQuantizer,
+            # which is fp32.
+            assert 'groupsize' in q_kwargs, f"a8w4dq quantization option must specify groupsize. Specified options: {q_kwargs}"
+            model = Int8DynActInt4WeightQuantizer(
+                groupsize=q_kwargs['groupsize']).quantize(model)
         elif quantizer == "linear:gptq":
             linears_quantized = True
             model = WeightOnlyInt4GPTQQuantHandler(
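For readers unfamiliar with the torchao API used in the added branch, the quantizer can also be exercised on its own. A minimal sketch, assuming a toy bias-free linear module standing in for the real model (the module, sizes, and groupsize below are illustrative, not part of this change):

import torch
import torch.nn as nn
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

# Toy stand-in for the model being quantized; bias-free to match the 8da4w
# linear replacement, which (as in the deleted handler below) assumes bias=False.
model = nn.Sequential(nn.Linear(256, 256, bias=False))

# groupsize must divide in_features; the default precision (fp32) determines
# the dtype of the quantized linear's output, as the comments above note.
quantizer = Int8DynActInt4WeightQuantizer(groupsize=128)
model = quantizer.quantize(model)

out = model(torch.randn(1, 256))
print(out.dtype)  # torch.float32 with the default precision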
@@ -968,273 +979,12 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
 
 #########################################################################
-##### Int8 Dynamic Activations 4 Bit Weights #####
-
-
-def prepare_int4_weight_and_scales_and_zeros(weight, groupsize, precision):
-    weight_int8, scales, zeros = group_quantize_tensor_symmetric(
-        weight,
-        n_bit=4,
-        groupsize=groupsize,
-        precision=precision,
-    )
-    # TODO: better API
-    # weight_int4packed = torch.ops.quantized_decomposed.pack_int4_from_int8(weight_int8)
-    return weight_int8, scales, zeros
-
-
-def linear_forward_8da4w(
-    x, weight_int8, scales, zeros, out_features, groupsize, precision
-):
-    x = per_token_dynamic_quant(x)
-    # TODO: verify and remove following reshape code
-    # origin_x_size = x.size()
-    # x = x.reshape(-1, origin_x_size[-1])
-
-    # TODO: better API
-    # weight_int8 = torch.ops.quantized_decomposed.unpack_int4_to_int8(weight_int4packed)
-    n_bit = 4
-    quant_min = -(2 ** (n_bit - 1))
-    quant_max = 2 ** (n_bit - 1) - 1
-    w_dq = torch.ops.quantized_decomposed.dequantize_per_channel_group(
-        weight_int8,
-        scales,
-        zeros,
-        quant_min,
-        quant_max,
-        torch.int8,
-        groupsize,
-        precision,
-    )
-
-    # x = x.to(torch.float16)
-    # w_dq = w_dq.to(torch.float16)
-    c = torch.nn.functional.linear(x, w_dq)
-
-    # new_shape = origin_x_size[:-1] + (out_features,)
-    # c = c.reshape(new_shape)
-
-    return c
-
-
-def find_multiple(n: int, *args: Tuple[int]) -> int:
-    k: int = reduce(lambda x, y: x * y // gcd(x, y), args + (1,))  # type: ignore[9]
-    if n % k == 0:
-        return n
-    return n + k - (n % k)
-
+##### GPTQ #####
 
 def _check_linear_int4_k(k, groupsize=1):
     return k % groupsize == 0
 
 
-def _calc_padded_size_linear_int4(k, groupsize=1):
-    return find_multiple(k, groupsize)
-
-
-def replace_linear_8da4w(
-    module,
-    groupsize,
-    padding_allowed,
-    precision,
-    scales_precision,
-):
-    for name, child in module.named_children():
-        if isinstance(child, nn.Linear):
-            if _check_linear_int4_k(child.in_features, groupsize) or padding_allowed:
-                setattr(
-                    module,
-                    name,
-                    Int8DynActInt4WeightLinear(
-                        child.in_features,
-                        child.out_features,
-                        bias=False,
-                        groupsize=groupsize,
-                        precision=precision,
-                        scales_precision=scales_precision,
-                    ),
-                )
-        else:
-            replace_linear_8da4w(
-                child,
-                groupsize,
-                padding_allowed,
-                precision,
-                scales_precision,
-            )
-
-
-class Int8DynActInt4WeightQuantHandler(QuantHandler):
-    def __init__(
-        self,
-        mod,
-        device,
-        *,
-        groupsize=256,
-        padding_allowed=False,
-        precision=torch.float32,
-        scales_precision=torch.float32,
-    ):
-        self.mod = mod
-        self.device = device
-        self.groupsize = groupsize
-        self.padding_allowed = padding_allowed
-        self.precision = precision
-        self.scales_precision = scales_precision
-        # assert groupsize in [32, 64, 128, 256]
-
-    @torch.no_grad()
-    def create_quantized_state_dict(self):
-        cur_state_dict = self.mod.state_dict()
-        for fqn, mod in self.mod.named_modules():
-            if isinstance(mod, torch.nn.Linear):
-                assert not mod.bias
-                in_features = mod.in_features
-                # print("in features:", in_features, " out features:", out_features)
-                # assert out_features % 8 == 0, "require out_features % 8 == 0"
-                # print(f"linear: {fqn}, in={in_features}, out={out_features}")
-
-                assert (
-                    in_features % self.groupsize == 0
-                ), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0"
-
-                weight = mod.weight.data
-                """
-                if not _check_linear_int4_k(
-                    in_features, self.groupsize
-                ):
-                    if self.padding_allowed:
-                        print(
-                            f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
-                        )
-                        padded_in_features = _calc_padded_size_linear_int4(
-                            in_features, self.groupsize
-                        )
-                        weight = F.pad(
-                            weight, pad=(0, padded_in_features - in_features)
-                        )
-                    else:
-                        raise RuntimeError(
-                            f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
-                            + "and that groupsize"
-                        )
-                """
-                (
-                    weight_int4pack,
-                    scales,
-                    zeros,
-                ) = prepare_int4_weight_and_scales_and_zeros(
-                    weight.to(self.precision),
-                    self.groupsize,
-                    self.scales_precision,
-                )
-                cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
-                cur_state_dict[f"{fqn}.scales"] = scales.to("cpu")
-                cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu")
-
-        return cur_state_dict
-
-    def convert_for_runtime(self):
-        replace_linear_8da4w(
-            self.mod,
-            self.groupsize,
-            self.padding_allowed,
-            self.precision,
-            self.scales_precision,
-        )
-        return self.mod
-
-    def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict()
-        self.convert_for_runtime()
-        self.mod.load_state_dict(model_updated_state_dict)
-        return self.mod
-
-
-class Int8DynActInt4WeightLinear(torch.nn.Module):
-    __constants__ = ["in_features", "out_features"]
-
-    in_features: int
-    out_features: int
-    weight: torch.Tensor
-
-    """
-    This module implements a dynamic quantized linear layer with int4 weight.
-    Weights are per channel groupwise quantized. Parameters of importance
-    groupsize: the number of elements in each quantized group
-    precision: precision of input and output. e.g. torch.float32 means input
-    activation is float32 and output is float32.
-    scales_precision: precision of per group scale.
-    """
-
-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-        bias=True,
-        device=None,
-        dtype=None,
-        groupsize: int = 256,
-        precision: torch.dtype = torch.float32,
-        scales_precision: torch.dtype = torch.float32,
-    ) -> None:
-        super().__init__()
-        # always pad if needed since it becomes a noop at runtime if not needed
-        # self.origin_in_features = in_features
-        assert (
-            in_features % groupsize == 0
-        ), f"require in_features:{in_features} % groupsize:{groupsize} == 0"
-        # in_features = _calc_padded_size_linear_int4(
-        #     in_features, groupsize
-        # )
-        self.in_features = in_features
-        self.out_features = out_features
-        assert not bias, "require bias=False"
-        self.groupsize = groupsize
-        # Precision of the activation which also indicates
-        # output precision of the dynamically quantized linear layer
-        # that his module represents.
-        self.precision = precision
-
-        # currently storing unpacked int8 weights
-        self.register_buffer(
-            "weight",
-            torch.empty((out_features, in_features), dtype=torch.int8),
-        )
-        self.register_buffer(
-            "scales",
-            torch.empty(
-                (out_features, in_features // groupsize),
-                dtype=scales_precision,
-            ),
-        )
-        self.register_buffer(
-            "zeros",
-            torch.empty(
-                (out_features, in_features // groupsize),
-                dtype=scales_precision,
-            ),
-        )
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        input = input.to(self.precision)
-        # padding is removed for perf
-        # input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
-        return linear_forward_8da4w(
-            input,
-            self.weight,
-            self.scales,
-            self.zeros,
-            self.out_features,
-            self.groupsize,
-            self.precision,
-        )
-
-
-#########################################################################
-##### GPTQ #####
-
-
 class GPTQQuantHandler(QuantHandler):
     """
     This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class.
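For context on what the deleted helpers above compute: linear_forward_8da4w dynamically quantizes the activation per token and dequantizes per-group, symmetrically quantized int4 weights before a regular floating-point linear. A minimal sketch of that per-group int4 round trip; the helper name and the scale convention here are illustrative, not the exact torchao kernels:

import torch

def quantize_group_symmetric(w_group: torch.Tensor, n_bit: int = 4):
    # Signed 4-bit range, matching quant_min/quant_max in the deleted code: [-8, 7].
    quant_min, quant_max = -(2 ** (n_bit - 1)), 2 ** (n_bit - 1) - 1
    # One common symmetric convention: scale from the max magnitude, zero-point fixed at 0.
    scale = w_group.abs().max().clamp(min=1e-6) / quant_max
    q = torch.clamp(torch.round(w_group / scale), quant_min, quant_max).to(torch.int8)
    return q, scale

w = torch.randn(128)                # one group of `groupsize` weights
q, scale = quantize_group_symmetric(w)
w_dq = q.to(torch.float32) * scale  # what dequantize_per_channel_group does per group
print((w - w_dq).abs().max())       # per-group quantization error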
@@ -1445,77 +1195,6 @@ def quantized_model(self) -> nn.Module:
         return self.mod
 
 
-# class Int8DynActInt4WeightGPTQQuantHandler(GPTQQuantHandler):
-#     def __init__(
-#         self,
-#         groupsize=128,
-#         inner_k_tiles=8,
-#         padding_allowed=True,
-#         precision=torch.float32,
-#     ):
-
-#         self.groupsize = groupsize
-#         self.inner_k_tiles = inner_k_tiles
-#         self.padding_allowed = padding_allowed
-#         self.precision = precision
-#         self.dyn_quant_func = lambda x: per_token_dynamic_quant(x)
-#         n_bit = 4
-#         self.get_qparams_func = lambda w: get_group_qparams_symmetric(
-#             w, n_bit, groupsize, self.precision
-#         )
-#         quant_min = -(2 ** (n_bit - 1))
-#         quant_max = 2 ** (n_bit - 1) - 1
-#         self.quantize_func = lambda w, qparams: torch.ops.quantized_decomposed.quantize_per_channel_group(
-#             w, qparams[0], qparams[1], quant_min, quant_max, torch.int8, groupsize
-#         )
-#         self.dequantize_func = lambda q, qparams: torch.ops.quantized_decomposed.dequantize_per_channel_group(
-#             q,
-#             qparams[0],
-#             qparams[1],
-#             quant_min,
-#             quant_max,
-#             torch.int8,
-#             groupsize,
-#             self.precision,
-#         )
-#         self.combine_qparams_list_func = lambda qparams_list: [
-#             torch.cat(x, dim=1) for x in zip(*qparams_list)
-#         ]
-#         # skip unless padding_allowed=True or its correctly sized
-#         self.skip_layer_func = lambda linear_weight: not (
-#             _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles)
-#             or padding_allowed
-#         )
-
-#         # we need to do the padding here, both for q and the qparams if necessary
-#         def make_names_and_values_dict_func(q, qparams):
-#             k = q.shape[1]
-#             new_k = _calc_padded_size_linear_int4(k, groupsize, inner_k_tiles)
-#             # how much we need to pad the weight
-#             delta_k = new_k - q.shape[1]
-#             final_q = F.pad(q, pad=(0, delta_k))
-#             scales_and_zeros = pack_scales_and_zeros(*qparams, precision=self.precision)
-#             # how many new groups we need for padded weight
-#             delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
-#             # TODO: split scales and zero_points
-#             final_s_and_z = F.pad(
-#                 scales_and_zeros, pad=(0, 0, 0, 0, 0, delta_groups), value=1
-#             )
-#             return {"weight": final_q, "scales_and_zeros": final_s_and_z}
-
-#         self.make_names_and_values_dict_func = make_names_and_values_dict_func
-#         super().__init__()
-
-#     def convert_for_runtime(self, model):
-#         replace_linear_8da4w(
-#             model,
-#             self.groupsize,
-#             self.padding_allowed,
-#             torch.int8,
-#             self.precision,
-#         )
-#         return model
-
 ##################################################################
 ### WIP: HQQ ###
 