
Commit 32eb691

mikekgfb authored and malfet committed
Move Linear int4 to qops (#537)
* move int8 linear class and function into qops.py
* move Quantized Embedding to qops.py
* move int4 linear to qops
1 parent 5044874 commit 32eb691

File tree

qops.py
quantize.py

2 files changed: +150 -116 lines changed

qops.py

Lines changed: 97 additions & 0 deletions
@@ -209,3 +209,100 @@ def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
         return r.view(indices.size() + (-1,))
 
         # r = result_weights.to(dtype=result_scales.dtype).view(list(result_weights.shape[:-1] + (scales.shape[1], -1, )) * result_scales.view(scales.shape[-1] + (scales.shape[1], 1, ))
+
+
+def linear_int4(input, weight_int4pack, scales_and_zeros, out_features, groupsize):
+    origin_input_size = input.size()
+    input = input.reshape(-1, origin_input_size[-1])
+
+    if "cuda" in str(input.device):
+        c = torch.ops.aten._weight_int4pack_mm(
+            input.to(torch.bfloat16),
+            weight_int4pack,
+            groupsize,
+            scales_and_zeros.to(torch.bfloat16),
+        ).to(
+            input.dtype
+        )  # cast back to input.dtype
+    else:
+        c = torch.ops.aten._weight_int4pack_mm(
+            input,
+            weight_int4pack,
+            groupsize,
+            scales_and_zeros,
+        )
+    new_shape = origin_input_size[:-1] + (out_features,)
+    c = c.reshape(new_shape)
+    return c
+
+
+class LinearInt4(torch.nn.Module):
+    __constants__ = ["in_features", "out_features"]
+    in_features: int
+    out_features: int
+    weight: torch.Tensor
+    scales_and_zeros: torch.Tensor
+
+    def __init__(
+        self,
+        device: str,
+        in_features: int,
+        out_features: int,
+        bias=True,
+        dtype=None,
+        groupsize: int = 128,
+        inner_k_tiles: int = 8,
+    ) -> None:
+        super().__init__()
+        self.padding = not self._check_k(
+            k=in_features,
+            groupsize=groupsize,
+            inner_k_tiles=inner_k_tiles,
+        )
+        if self.padding:
+            self.origin_in_features = in_features
+            in_features = find_multiple(in_features, 1024)
+
+        self.in_features = in_features
+        self.out_features = out_features
+        assert not bias, "require bias=False"
+        self.groupsize = groupsize
+        self.inner_k_tiles = inner_k_tiles
+
+        assert out_features % 8 == 0, "require out_features % 8 == 0"
+        assert (
+            in_features % (inner_k_tiles * 16) == 0
+        ), "require in_features % (innerKTiles * 16) == 0"
+        self.register_buffer(
+            "weight",
+            torch.empty(
+                (
+                    out_features // 8,
+                    in_features // (inner_k_tiles * 16),
+                    32,
+                    inner_k_tiles // 2,
+                ),
+                dtype=torch.int32,
+                device=device,
+            ),
+        )
+        self.register_buffer(
+            "scales_and_zeros",
+            torch.empty(
+                (in_features // groupsize, out_features, 2),
+                dtype=get_precision(),
+                device=device,
+            ),
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self.padding:
+            input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
+        return linear_int4(
+            input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
+        )
+
+    @classmethod
+    def _check_k(cls, *, k, groupsize=1, inner_k_tiles=1):
+        return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
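For orientation, the relocated class keeps the contract the quantize.py version enforced: bias must be False, out_features must be divisible by 8, and in_features must be divisible by inner_k_tiles * 16, with non-conforming in_features padded up to a multiple of 1024. Below is a minimal construction sketch (not part of the diff), assuming qops.py from this commit is importable; the empty weight and scales_and_zeros buffers would normally be populated by the int4 quantizer in quantize.py before forward is used.

from qops import LinearInt4  # re-exported to quantize.py as WeightOnlyInt4Linear

# Dimensions chosen to satisfy the asserts in __init__:
# out_features % 8 == 0 and in_features % (inner_k_tiles * 16) == 0.
layer = LinearInt4(
    device="cpu",
    in_features=4096,
    out_features=4096,
    bias=False,
    groupsize=128,
    inner_k_tiles=8,
)

# The registered buffers are placeholders until a quantizer packs real values.
print(layer.weight.shape)            # torch.Size([512, 32, 32, 4]), dtype=torch.int32
print(layer.scales_and_zeros.shape)  # torch.Size([32, 4096, 2])

# An in_features that fails _check_k triggers the padding path:
# it is rounded up to a multiple of 1024 (4000 -> 4096 here).
padded = LinearInt4(device="cpu", in_features=4000, out_features=4096, bias=False)
print(padded.padding, padded.in_features)  # True 4096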

quantize.py

Lines changed: 53 additions & 116 deletions
@@ -25,6 +25,12 @@
 )
 from qops import LinearInt8 as WeightOnlyInt8Linear, QuantizedEmbedding
 
+from qops import (
+    LinearInt4 as WeightOnlyInt4Linear,
+    LinearInt8 as WeightOnlyInt8Linear,
+    QuantizedEmbedding,
+)
+
 
 #########################################################################
 ### torchchat quantization API ###
@@ -606,31 +612,6 @@ def _int4_calc_padded_size(k, groupsize=1, innner_k_tiles=1):
     return find_multiple(k, 1024)
 
 
-def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
-    origin_x_size = x.size()
-    x = x.reshape(-1, origin_x_size[-1])
-
-    if "cuda" in str(x.device):
-        c = torch.ops.aten._weight_int4pack_mm(
-            x.to(torch.bfloat16),
-            weight_int4pack,
-            groupsize,
-            scales_and_zeros.to(torch.bfloat16),
-        ).to(
-            x.dtype
-        )  # cast back to x.dtype
-    else:
-        c = torch.ops.aten._weight_int4pack_mm(
-            x,
-            weight_int4pack,
-            groupsize,
-            scales_and_zeros,
-        )
-    new_shape = origin_x_size[:-1] + (out_features,)
-    c = c.reshape(new_shape)
-    return c
-
-
 def replace_linear_int4(
     module,
     device,
@@ -640,9 +621,10 @@ def replace_linear_int4(
 ):
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
-            if (
-                _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles)
-                or padding_allowed
+            if padding_allowed or WeightOnlyInt4Linear._check_k(
+                k=child.in_features,
+                groupsize=groupsize,
+                inner_k_tiles=inner_k_tiles,
             ):
                 setattr(
                     module,
@@ -704,8 +686,10 @@ def create_quantized_state_dict(self):
                 # print(f"linear: {fqn}, in={in_features}, out={out_features}")
 
                 weight = mod.weight.data
-                if not _check_linear_int4_k(
-                    in_features, self.groupsize, self.inner_k_tiles
+                if not WeightOnlyInt4Linear._check_k(
+                    k=in_features,
+                    groupsize=self.groupsize,
+                    inner_k_tiles=self.inner_k_tiles,
                 ):
                     if self.padding_allowed:
                         print(
@@ -751,85 +735,23 @@ def quantized_model(self) -> nn.Module:
         return self.model_
 
 
-class WeightOnlyInt4Linear(torch.nn.Module):
-    __constants__ = ["in_features", "out_features"]
-    in_features: int
-    out_features: int
-    weight: torch.Tensor
-    scales_and_zeros: torch.Tensor
-
-    def __init__(
-        self,
-        device: str,
-        in_features: int,
-        out_features: int,
-        bias=True,
-        dtype=None,
-        groupsize: int = 128,
-        inner_k_tiles: int = 8,
-    ) -> None:
-        super().__init__()
-        self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
-        if self.padding:
-            self.origin_in_features = in_features
-            in_features = find_multiple(in_features, 1024)
-
-        self.in_features = in_features
-        self.out_features = out_features
-        assert not bias, "require bias=False"
-        self.groupsize = groupsize
-        self.inner_k_tiles = inner_k_tiles
-
-        assert out_features % 8 == 0, "require out_features % 8 == 0"
-        assert (
-            in_features % (inner_k_tiles * 16) == 0
-        ), "require in_features % (innerKTiles * 16) == 0"
-        self.register_buffer(
-            "weight",
-            torch.empty(
-                (
-                    out_features // 8,
-                    in_features // (inner_k_tiles * 16),
-                    32,
-                    inner_k_tiles // 2,
-                ),
-                dtype=torch.int32,
-                device=device,
-            ),
-        )
-        self.register_buffer(
-            "scales_and_zeros",
-            torch.empty(
-                (in_features // groupsize, out_features, 2),
-                dtype=get_precision(),
-                device=device,
-            ),
-        )
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.padding:
-            input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
-        return linear_forward_int4(
-            input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
-        )
-
-
 #########################################################################
 ##### GPTQ #####
 
 
-def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
-    return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
-
-
 class GPTQQuantHandler(QuantHandler):
-    """
-    This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class.
-    Unlike the base QuantHandler class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement
-    __init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime.
-
-    The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and
-    create_quantized_state_dict. Here is a description of each function.
+    """This class implements a GPTQ QuantHandler that can be used to
+    apply GPTQ to a model in concert with the GenericGPTQRunner class.
+    Unlike the base QuantHandler class, the user does not need to
+    implement the create_quantized_state_dict, instead they have to
+    reimplement __init__ such that it defines the functions for the
+    quantization mode. User is expected to reimplement
+    convert_for_runtime.
+
+    The following functions (which must be defined in __init__) are
+    used to define the quantization mode for both GPTQ and
+    create_quantized_state_dict. Here is a description of each
+    function.
 
     get_qparams_func:
         A function that calculates the quantization qparams for an input tensor.
@@ -839,9 +761,11 @@ class GPTQQuantHandler(QuantHandler):
             qparams: it can have any format but will need to be handled by the other defined functions below.
 
     quantize_func:
-        A function that applies quantization to an input tensor. It should be noted
-        that this function needs to be able to handle quantizing the entire weight tensor, a single group,
-        or a single column.
+        A function that applies quantization to an input tensor. It
+        should be noted that this function needs to be able to handle
+        quantizing the entire weight tensor, a single group, or a
+        single column.
+
         Args:
             weight: A 2d weight tensor with non-integer dtype.
             qparams: the output from get_qparams_func
@@ -850,9 +774,11 @@ class GPTQQuantHandler(QuantHandler):
 
 
    dequantize_func:
-        A function that dequantizes an input quantized weight tensor. It should be noted
-        that this function needs to be able to handle dequantizing the entire weight tensor, a single group,
-        or a single column.
+        A function that dequantizes an input quantized weight
+        tensor. It should be noted that this function needs to be able
+        to handle dequantizing the entire weight tensor, a single
+        group, or a single column.
+
         Args:
             quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
             qparams: the output from get_qparams_func
@@ -861,6 +787,7 @@ class GPTQQuantHandler(QuantHandler):
 
     combine_qparams_list_func:
         A function that combines several qparams into one qparam.
+
         Args:
             qparams_list: a list of qparams objects, each obtained by calling get_qparams_func
            on a single group from a weight tensor
@@ -875,13 +802,17 @@ class GPTQQuantHandler(QuantHandler):
             skip: boolean indicating whether layer should be skipped
 
     make_names_and_values_dict_func:
-        A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they
-        should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here.
+        A function that prepares the qparams and quantized_weight and
+        creates a dictionary indicating how they should be inserted
+        into the state_dict. Generally any packing of the weight and
+        qparams should be done here.
+
         Args:
             quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
             qparams: the output from get_qparams_func
         Returns:
-            names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the
+            names_and_values_dict: a dictionary mapping the name of
+            the parameters of the quantized module to the
             corresponding quantized weights and qparams.
     """
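To make the reflowed docstring above concrete, here is a toy, hypothetical subclass that only wires up the callables the docstring lists. The attribute names follow the docstring, but the lambda bodies are placeholder per-group math rather than torchchat's actual int4/GPTQ scheme, and the import assumes the snippet runs from a torchchat checkout.

import torch
from quantize import GPTQQuantHandler  # quantize.py as of this commit


class ToyGPTQHandler(GPTQQuantHandler):
    """Hypothetical sketch: assigns the callables described in the docstring."""

    def __init__(self, model, groupsize: int = 128):
        self.model_ = model
        # qparams here are just a per-row scale for a signed 4-bit range;
        # real handlers may also carry zero points.
        self.get_qparams_func = (
            lambda w: w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 7
        )
        self.quantize_func = (
            lambda w, qparams: torch.clamp(torch.round(w / qparams), -8, 7)
        )
        self.dequantize_func = lambda q, qparams: q.to(qparams.dtype) * qparams
        # combine per-group qparams along the group dimension
        self.combine_qparams_list_func = (
            lambda qparams_list: torch.cat(qparams_list, dim=1)
        )
        self.skip_layer_func = lambda linear_weight: False
        # keys are illustrative; the real int4 handler packs weights and scales_and_zeros
        self.make_names_and_values_dict_func = lambda q, qparams: {
            "weight": q.to(torch.int8),
            "scales": qparams,
        }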

@@ -1026,14 +957,20 @@ def __init__(
         ]
         # skip unless padding_allowed=True or its correctly sized
         self.skip_layer_func = lambda linear_weight: not (
-            _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles)
-            or padding_allowed
+            padding_allowed
+            or WeightOnlyInt4Linear._check_k(
+                k=linear_weight.shape[-1],
+                groupsize=groupsize,
+                inner_k_tiles=inner_k_tiles,
+            )
         )
 
         # we need to do the padding here, both for q and the qparams if necessary
         def make_names_and_values_dict_func(q, qparams):
             k = q.shape[1]
-            if not _check_linear_int4_k(k, groupsize, inner_k_tiles):
+            if not WeightOnlyInt4Linear._check_k(
+                k=k, groupsize=groupsize, inner_k_tiles=inner_k_tiles
+            ):
                 new_k = find_multiple(k, 1024)
             else:
                 new_k = k
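The recurring change in quantize.py is that every call site which used the module-level _check_linear_int4_k now asks WeightOnlyInt4Linear._check_k whether a k (in_features) value fits the packed int4 kernel, and pads it up to a multiple of 1024 when it does not. A standalone sketch of that decision follows; find_multiple is restated here only so the snippet runs on its own (torchchat imports its own helper), and the sizes are illustrative.

def find_multiple(n: int, k: int) -> int:
    # smallest multiple of k that is >= n (mirrors the helper quantize.py imports)
    return n if n % k == 0 else n + k - (n % k)


def check_k(k: int, groupsize: int = 1, inner_k_tiles: int = 1) -> bool:
    # same test as LinearInt4._check_k: k must divide evenly into groups
    # and into the inner_k_tiles * 16 packing blocks used by _weight_int4pack_mm
    return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0


for in_features in (4096, 4000):
    ok = check_k(in_features, groupsize=128, inner_k_tiles=8)
    padded = in_features if ok else find_multiple(in_features, 1024)
    print(f"in_features={in_features}: usable as-is={ok}, padded to {padded}")
# in_features=4096: usable as-is=True, padded to 4096
# in_features=4000: usable as-is=False, padded to 4096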
