Commit bea3041

Author: Nathanael See
[ET-VK][int4] patch 4-bit source transformation quantizer to support linear modules with biases
While LLaMa's linear modules do not have biases, some models do. Add support for biases in the source transform quantizer.

Differential Revision: [D69072087](https://our.internmc.facebook.com/intern/diff/D69072087/)

[ghstack-poisoned]
1 parent e63c923 commit bea3041

1 file changed: +18 -6 lines

backends/vulkan/_passes/int4_weight_only_quantizer.py

Lines changed: 18 additions & 6 deletions
@@ -39,11 +39,12 @@ def __init__(
             from torchao.utils import find_multiple

             self.origin_in_features = in_features
-            in_features = find_multiple(in_features, (1024,))
+            # pyre-ignore[6]: Incompatible parameter type
+            in_features = find_multiple(in_features, 1024)

+        self.use_bias = bias
         self.in_features = in_features
         self.out_features = out_features
-        assert not bias, "require bias=False"
         self.device = device
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
@@ -80,20 +81,28 @@ def __init__(
                 device=device,
             ),
         )
+        if bias:
+            self.register_buffer(
+                "bias",
+                torch.empty((out_features,), dtype=torch.float32, device=device),
+            )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.padding:
             input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
         # The forward method is replaced. In the original implementation, the forward
         # method is torchao.quantization.GPTQ.linear_forward_int4; here a Vulkan custom
         # operator is called instead.
-        return torch.ops.et_vk.linear_weight_int4(
+        r = torch.ops.et_vk.linear_weight_int4(
             input,
             self.weight,
             self.groupsize,
             self.scales_and_zeros,
             self.inner_k_tiles,
         )
+        if self.use_bias:
+            return r + self.bias
+        return r


 # This function is coped from torchao.quantization.GPTQ._replace_linear_int4
@@ -128,7 +137,7 @@ def _vk_replace_linear_int4(
                 new_linear = linear_class(
                     child.in_features,
                     child.out_features,
-                    bias=False,
+                    bias=child.bias is not None,
                     device=child.weight.device,
                     groupsize=groupsize,
                     inner_k_tiles=inner_k_tiles,
@@ -138,6 +147,9 @@ def _vk_replace_linear_int4(
                 if copy_weights and child.weight.device != torch.device("meta"):
                     # pyre-fixme[16]: `Module` has no attribute `weight`.
                     new_linear.weight = child.weight
+                    if child.bias is not None:
+                        # pyre-fixme[16]: `Module` has no attribute `bias`.
+                        new_linear.bias = child.bias
                 setattr(module, name, new_linear)
         else:
             _vk_replace_linear_int4(
@@ -189,7 +201,6 @@ def _create_quantized_state_dict(
                 mod.out_features < self.feature_limit
                 and mod.in_features < self.feature_limit
             ):
-                assert not mod.bias
                 out_features = mod.out_features
                 in_features = mod.in_features
                 logging.info(f"linear: {fqn}, in={in_features}, out={out_features}")
@@ -210,7 +221,8 @@ def _create_quantized_state_dict(
                     logging.warn(
                         f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
                     )
-                    padded_in_features = find_multiple(in_features, (1024,))
+                    # pyre-ignore[6]: Incompatible parameter type
+                    padded_in_features = find_multiple(in_features, 1024)
                     weight = F.pad(
                         weight, pad=(0, padded_in_features - in_features)
                     )
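
For context, below is a minimal standalone sketch of the bias-handling pattern this diff introduces. It is not code from the repository: the class name is illustrative, and a plain `torch.nn.functional.linear` call stands in for the `torch.ops.et_vk.linear_weight_int4` custom op, so the snippet only demonstrates how the bias buffer is registered when present and added after the quantized-linear call.

```python
# Sketch only (not from the repository): keep the bias as a buffer when the
# replaced nn.Linear had one, run the int4 linear op (stubbed here with a
# plain float matmul), then add the bias afterwards.
import torch
import torch.nn as nn
import torch.nn.functional as F


class Int4LinearWithBiasSketch(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super().__init__()
        self.use_bias = bias
        # Stand-in for the packed int4 weight; the real module stores packed
        # weights plus scales_and_zeros consumed by the Vulkan custom op.
        self.register_buffer("weight", torch.randn(out_features, in_features))
        if bias:
            self.register_buffer(
                "bias", torch.zeros((out_features,), dtype=torch.float32)
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The real forward calls torch.ops.et_vk.linear_weight_int4 here.
        r = F.linear(x, self.weight)
        if self.use_bias:
            return r + self.bias
        return r


# The quantizer would copy the original module's bias into the buffer; here we
# just check that the biased path produces the expected output shape.
m = Int4LinearWithBiasSketch(16, 8, bias=True)
print(m(torch.randn(2, 16)).shape)  # torch.Size([2, 8])
```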
