Move Linear int4 to qops #537

Merged
merged 4 commits on Apr 29, 2024

Changes from all commits
97 changes: 97 additions & 0 deletions qops.py
@@ -209,3 +209,100 @@ def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
return r.view(indices.size() + (-1,))

# r = result_weights.to(dtype=result_scales.dtype).view(list(result_weights.shape[:-1] + (scales.shape[1], -1, )) * result_scales.view(scales.shape[-1] + (scales.shape[1], 1, ))


def linear_int4(input, weight_int4pack, scales_and_zeros, out_features, groupsize):
origin_input_size = input.size()
input = input.reshape(-1, origin_input_size[-1])

if "cuda" in str(input.device):
c = torch.ops.aten._weight_int4pack_mm(
input.to(torch.bfloat16),
weight_int4pack,
groupsize,
scales_and_zeros.to(torch.bfloat16),
).to(
input.dtype
) # cast back to input.dtype
else:
c = torch.ops.aten._weight_int4pack_mm(
input,
weight_int4pack,
groupsize,
scales_and_zeros,
)
new_shape = origin_input_size[:-1] + (out_features,)
c = c.reshape(new_shape)
return c


class LinearInt4(torch.nn.Module):
__constants__ = ["in_features", "out_features"]
in_features: int
out_features: int
weight: torch.Tensor
scales_and_zeros: torch.Tensor

def __init__(
self,
device: str,
in_features: int,
out_features: int,
bias=True,
dtype=None,
groupsize: int = 128,
inner_k_tiles: int = 8,
) -> None:
super().__init__()
self.padding = not self._check_k(
k=in_features,
groupsize=groupsize,
inner_k_tiles=inner_k_tiles,
)
if self.padding:
self.origin_in_features = in_features
in_features = find_multiple(in_features, 1024)

self.in_features = in_features
self.out_features = out_features
assert not bias, "require bias=False"
self.groupsize = groupsize
self.inner_k_tiles = inner_k_tiles

assert out_features % 8 == 0, "require out_features % 8 == 0"
assert (
in_features % (inner_k_tiles * 16) == 0
), "require in_features % (innerKTiles * 16) == 0"
self.register_buffer(
"weight",
torch.empty(
(
out_features // 8,
in_features // (inner_k_tiles * 16),
32,
inner_k_tiles // 2,
),
dtype=torch.int32,
device=device,
),
)
self.register_buffer(
"scales_and_zeros",
torch.empty(
(in_features // groupsize, out_features, 2),
dtype=get_precision(),
device=device,
),
)

def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.padding:
input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
return linear_int4(
input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
)

@classmethod
def _check_k(cls, *, k, groupsize=1, inner_k_tiles=1):
return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
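
For orientation, and not part of this diff: a minimal usage sketch of the relocated qops.LinearInt4. It assumes qops.py and the torchchat helpers it imports (find_multiple, get_precision) are importable, and that the weight and scales_and_zeros buffers would be filled by the int4 quantizer before the layer is actually used.

import torch
from qops import LinearInt4

# Construct the module; the class asserts bias=False and checks the int4
# packing constraints (out_features % 8 == 0, in_features % (inner_k_tiles * 16) == 0).
layer = LinearInt4(
    device="cpu",
    in_features=4096,
    out_features=4096,
    bias=False,
    groupsize=128,
    inner_k_tiles=8,
)

# Buffer layouts registered by __init__:
#   weight:           (out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2)
#   scales_and_zeros: (in_features // groupsize, out_features, 2)
print(tuple(layer.weight.shape))            # (512, 32, 32, 4)
print(tuple(layer.scales_and_zeros.shape))  # (32, 4096, 2)

# _check_k drives the padding path: k must divide evenly by both groupsize and
# inner_k_tiles * 16; otherwise in_features is rounded up to a multiple of 1024
# and forward() pads the input with F.pad.
assert LinearInt4._check_k(k=4096, groupsize=128, inner_k_tiles=8)
assert not LinearInt4._check_k(k=4000, groupsize=128, inner_k_tiles=8)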

169 changes: 53 additions & 116 deletions quantize.py
@@ -25,6 +25,12 @@
)
from qops import LinearInt8 as WeightOnlyInt8Linear, QuantizedEmbedding

from qops import (
LinearInt4 as WeightOnlyInt4Linear,
LinearInt8 as WeightOnlyInt8Linear,
QuantizedEmbedding,
)


#########################################################################
### torchchat quantization API ###
@@ -606,31 +612,6 @@ def _int4_calc_padded_size(k, groupsize=1, innner_k_tiles=1):
return find_multiple(k, 1024)


def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
origin_x_size = x.size()
x = x.reshape(-1, origin_x_size[-1])

if "cuda" in str(x.device):
c = torch.ops.aten._weight_int4pack_mm(
x.to(torch.bfloat16),
weight_int4pack,
groupsize,
scales_and_zeros.to(torch.bfloat16),
).to(
x.dtype
) # cast back to x.dtype
else:
c = torch.ops.aten._weight_int4pack_mm(
x,
weight_int4pack,
groupsize,
scales_and_zeros,
)
new_shape = origin_x_size[:-1] + (out_features,)
c = c.reshape(new_shape)
return c


def replace_linear_int4(
module,
device,
@@ -640,9 +621,10 @@ def replace_linear_int4(
):
for name, child in module.named_children():
if isinstance(child, nn.Linear):
if (
_check_linear_int4_k(child.in_features, groupsize, inner_k_tiles)
or padding_allowed
if padding_allowed or WeightOnlyInt4Linear._check_k(
k=child.in_features,
groupsize=groupsize,
inner_k_tiles=inner_k_tiles,
):
setattr(
module,
@@ -704,8 +686,10 @@ def create_quantized_state_dict(self):
# print(f"linear: {fqn}, in={in_features}, out={out_features}")

weight = mod.weight.data
if not _check_linear_int4_k(
in_features, self.groupsize, self.inner_k_tiles
if not WeightOnlyInt4Linear._check_k(
k=in_features,
groupsize=self.groupsize,
inner_k_tiles=self.inner_k_tiles,
):
if self.padding_allowed:
print(
@@ -751,85 +735,23 @@ def quantized_model(self) -> nn.Module:
return self.model_


class WeightOnlyInt4Linear(torch.nn.Module):
__constants__ = ["in_features", "out_features"]
in_features: int
out_features: int
weight: torch.Tensor
scales_and_zeros: torch.Tensor

def __init__(
self,
device: str,
in_features: int,
out_features: int,
bias=True,
dtype=None,
groupsize: int = 128,
inner_k_tiles: int = 8,
) -> None:
super().__init__()
self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
if self.padding:
self.origin_in_features = in_features
in_features = find_multiple(in_features, 1024)

self.in_features = in_features
self.out_features = out_features
assert not bias, "require bias=False"
self.groupsize = groupsize
self.inner_k_tiles = inner_k_tiles

assert out_features % 8 == 0, "require out_features % 8 == 0"
assert (
in_features % (inner_k_tiles * 16) == 0
), "require in_features % (innerKTiles * 16) == 0"
self.register_buffer(
"weight",
torch.empty(
(
out_features // 8,
in_features // (inner_k_tiles * 16),
32,
inner_k_tiles // 2,
),
dtype=torch.int32,
device=device,
),
)
self.register_buffer(
"scales_and_zeros",
torch.empty(
(in_features // groupsize, out_features, 2),
dtype=get_precision(),
device=device,
),
)

def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.padding:
input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
return linear_forward_int4(
input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
)


#########################################################################
##### GPTQ #####


def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0


class GPTQQuantHandler(QuantHandler):
"""
This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class.
Unlike the base QuantHandler class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement
__init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime.

The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and
create_quantized_state_dict. Here is a description of each function.
"""This class implements a GPTQ QuantHandler that can be used to
apply GPTQ to a model in concert with the GenericGPTQRunner class.
Unlike the base QuantHandler class, the user does not need to
implement the create_quantized_state_dict, instead they have to
reimplement __init__ such that it defines the functions for the
quantization mode. User is expected to reimplement
convert_for_runtime.

The following functions (which must be defined in __init__) are
used to define the quantization mode for both GPTQ and
create_quantized_state_dict. Here is a description of each
function.

get_qparams_func:
A function that calculates the quantization qparams for an input tensor.
@@ -839,9 +761,11 @@ class GPTQQuantHandler(QuantHandler):
qparams: it can have any format but will need to be handled by the other defined functions below.

quantize_func:
A function that applies quantization to an input tensor. It should be noted
that this function needs to be able to handle quantizing the entire weight tensor, a single group,
or a single column.
A function that applies quantization to an input tensor. It
should be noted that this function needs to be able to handle
quantizing the entire weight tensor, a single group, or a
single column.

Args:
weight: A 2d weight tensor with non-integer dtype.
qparams: the output from get_qparams_func
@@ -850,9 +774,11 @@ class GPTQQuantHandler(QuantHandler):


dequantize_func:
A function that dequantizes an input quantized weight tensor. It should be noted
that this function needs to be able to handle dequantizing the entire weight tensor, a single group,
or a single column.
A function that dequantizes an input quantized weight
tensor. It should be noted that this function needs to be able
to handle dequantizing the entire weight tensor, a single
group, or a single column.

Args:
quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
qparams: the output from get_qparams_func
@@ -861,6 +787,7 @@ class GPTQQuantHandler(QuantHandler):

combine_qparams_list_func:
A function that combines several qparams into one qparam.

Args:
qparams_list: a list of qparams objects, each obtained by calling get_qparams_func
on a single group from a weight tensor
@@ -875,13 +802,17 @@ class GPTQQuantHandler(QuantHandler):
skip: boolean indicating whether layer should be skipped

make_names_and_values_dict_func:
A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they
should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here.
A function that prepares the qparams and quantized_weight and
creates a dictionary indicating how they should be inserted
into the state_dict. Generally any packing of the weight and
qparams should be done here.

Args:
quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
qparams: the output from get_qparams_func
Returns:
names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the
names_and_values_dict: a dictionary mapping the name of
the parameters of the quantized module to the
corresponding quantized weights and qparams.
"""

@@ -1026,14 +957,20 @@ def __init__(
]
# skip unless padding_allowed=True or its correctly sized
self.skip_layer_func = lambda linear_weight: not (
_check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles)
or padding_allowed
padding_allowed
or WeightOnlyInt4Linear._check_k(
k=linear_weight.shape[-1],
groupsize=groupsize,
inner_k_tiles=inner_k_tiles,
)
)

# we need to do the padding here, both for q and the qparams if necessary
def make_names_and_values_dict_func(q, qparams):
k = q.shape[1]
if not _check_linear_int4_k(k, groupsize, inner_k_tiles):
if not WeightOnlyInt4Linear._check_k(
k=k, groupsize=groupsize, inner_k_tiles=inner_k_tiles
):
new_k = find_multiple(k, 1024)
else:
new_k = k
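
A small sketch of the padding step that the comment above refers to (the rest of make_names_and_values_dict_func is collapsed in this view): when k fails WeightOnlyInt4Linear._check_k, the quantized weight, and its qparams if necessary, are zero-padded out to find_multiple(k, 1024). find_multiple is defined locally so the snippet stands alone; that the repo helper and the collapsed code behave exactly this way is an assumption.

import torch
import torch.nn.functional as F

def find_multiple(n: int, k: int) -> int:
    # round n up to the nearest multiple of k (assumed semantics of the repo helper)
    return n if n % k == 0 else n + k - (n % k)

k, groupsize, inner_k_tiles = 4000, 32, 8
assert not (k % groupsize == 0 and k % (inner_k_tiles * 16) == 0)  # fails _check_k

new_k = find_multiple(k, 1024)                    # 4096
q = torch.randint(0, 16, (256, k), dtype=torch.int32)
q_padded = F.pad(q, pad=(0, new_k - k))           # zero-pad the reduction (k) dim
qparams = torch.randn(k // groupsize, 256, 2)
qparams_padded = F.pad(qparams, pad=(0, 0, 0, 0, 0, new_k // groupsize - k // groupsize))
print(q_padded.shape, qparams_padded.shape)       # (256, 4096) and (128, 256, 2)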