
Commit 43d3037

Committed by: Nathanael See
[ET-VK][int4] patch 4-bit source transformation quantizer to support linear modules with biases
Pull Request resolved: #8224

While LLaMa does not have biases, there are some models which will have biases in their linear modules. Add support in the source transform quantizer for biases.

ghstack-source-id: 264874723
@exported-using-ghexport
Differential Revision: [D69072087](https://our.internmc.facebook.com/intern/diff/D69072087/)
1 parent: e63c923

File tree: 1 file changed (+18 −6)

backends/vulkan/_passes/int4_weight_only_quantizer.py

Lines changed: 18 additions & 6 deletions
@@ -39,11 +39,12 @@ def __init__(
             from torchao.utils import find_multiple
 
             self.origin_in_features = in_features
-            in_features = find_multiple(in_features, (1024,))
+            # pyre-ignore[6]: Incompatible parameter type
+            in_features = find_multiple(in_features, 1024)
 
+        self.use_bias = bias
         self.in_features = in_features
         self.out_features = out_features
-        assert not bias, "require bias=False"
         self.device = device
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
@@ -80,20 +81,28 @@ def __init__(
                 device=device,
             ),
         )
+        if bias:
+            self.register_buffer(
+                "bias",
+                torch.empty((out_features,), dtype=torch.float32, device=device),
+            )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.padding:
             input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
         # The forward method is replaced. In the original implementation, the forward
         # method is torchao.quantization.GPTQ.linear_forward_int4; here a Vulkan custom
         # operator is called instead.
-        return torch.ops.et_vk.linear_weight_int4(
+        r = torch.ops.et_vk.linear_weight_int4(
             input,
             self.weight,
             self.groupsize,
             self.scales_and_zeros,
             self.inner_k_tiles,
         )
+        if self.use_bias:
+            return r + self.bias
+        return r
 
 
 # This function is coped from torchao.quantization.GPTQ._replace_linear_int4
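
A minimal, self-contained sketch of the optional-bias pattern this hunk introduces; it is an illustration rather than the ExecuTorch implementation. PackedLinearSketch is a hypothetical name, and plain F.linear stands in for the torch.ops.et_vk.linear_weight_int4 custom op on the packed 4-bit weight.

import torch
import torch.nn.functional as F


class PackedLinearSketch(torch.nn.Module):
    # Hypothetical stand-in for the quantized linear module patched above: the
    # bias is registered only when the source nn.Linear had one, and is added
    # after the low-level linear op, mirroring the diff.
    def __init__(self, in_features: int, out_features: int, bias: bool = False) -> None:
        super().__init__()
        self.use_bias = bias
        self.register_buffer("weight", torch.empty((out_features, in_features)))
        if bias:
            self.register_buffer(
                "bias", torch.empty((out_features,), dtype=torch.float32)
            )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Stand-in for the custom op call; the bias is applied afterwards,
        # exactly as in the patched forward().
        r = F.linear(input, self.weight)
        if self.use_bias:
            return r + self.bias
        return r


# Illustrative usage: the bias path is taken only when bias=True.
m = PackedLinearSketch(8, 4, bias=True)
m.weight.normal_()
m.bias.zero_()
y = m(torch.randn(2, 8))  # shape (2, 4)

Registering the bias as a buffer rather than an nn.Parameter keeps it in the state dict and moves it with the module across devices without exposing it as a trainable parameter.
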
@@ -128,7 +137,7 @@ def _vk_replace_linear_int4(
                 new_linear = linear_class(
                     child.in_features,
                     child.out_features,
-                    bias=False,
+                    bias=child.bias is not None,
                     device=child.weight.device,
                     groupsize=groupsize,
                     inner_k_tiles=inner_k_tiles,
@@ -138,6 +147,9 @@ def _vk_replace_linear_int4(
                 if copy_weights and child.weight.device != torch.device("meta"):
                     # pyre-fixme[16]: `Module` has no attribute `weight`.
                     new_linear.weight = child.weight
+                    if child.bias is not None:
+                        # pyre-fixme[16]: `Module` has no attribute `bias`.
+                        new_linear.bias = child.bias
                 setattr(module, name, new_linear)
         else:
             _vk_replace_linear_int4(
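
A stripped-down sketch of the traversal that _vk_replace_linear_int4 performs after this change: walk named_children recursively, swap each eligible nn.Linear for the quantized replacement, and carry the bias across only when the source layer has one. The name replace_linear_sketch and its bare linear_class argument are assumptions for illustration; the real function also handles group-size checks, padding, meta-device weights, and skip functions.

import torch


def replace_linear_sketch(module: torch.nn.Module, linear_class) -> None:
    # linear_class is expected to accept (in_features, out_features, bias=...),
    # e.g. the PackedLinearSketch from the previous sketch.
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear):
            new_linear = linear_class(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,  # same test as in the diff
            )
            new_linear.weight = child.weight.detach()
            if child.bias is not None:
                # Copy the bias only when the source layer actually has one.
                new_linear.bias = child.bias.detach()
            setattr(module, name, new_linear)
        else:
            replace_linear_sketch(child, linear_class)


# Illustrative usage:
# model = torch.nn.Sequential(torch.nn.Linear(16, 8, bias=True))
# replace_linear_sketch(model, PackedLinearSketch)
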
@@ -189,7 +201,6 @@ def _create_quantized_state_dict(
                 mod.out_features < self.feature_limit
                 and mod.in_features < self.feature_limit
             ):
-                assert not mod.bias
                 out_features = mod.out_features
                 in_features = mod.in_features
                 logging.info(f"linear: {fqn}, in={in_features}, out={out_features}")
@@ -210,7 +221,8 @@ def _create_quantized_state_dict(
                         logging.warn(
                             f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
                         )
-                        padded_in_features = find_multiple(in_features, (1024,))
+                        # pyre-ignore[6]: Incompatible parameter type
+                        padded_in_features = find_multiple(in_features, 1024)
                         weight = F.pad(
                             weight, pad=(0, padded_in_features - in_features)
                         )
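
The last hunk keeps the existing padding path: when in_features is not a multiple of 1024, the quantizer rounds it up and zero-pads the weight columns, and forward() pads the activations the same way. A small worked sketch of that rounding and padding, with a local round_up helper standing in for torchao.utils.find_multiple (an assumption for illustration):

import torch
import torch.nn.functional as F


def round_up(n: int, k: int) -> int:
    # Smallest multiple of k that is >= n; local stand-in for find_multiple.
    return n if n % k == 0 else n + k - (n % k)


in_features, out_features = 1000, 32
weight = torch.randn(out_features, in_features)

padded_in_features = round_up(in_features, 1024)  # 1024
# Zero-pad the last (column) dimension of the weight, as in the hunk above:
weight = F.pad(weight, pad=(0, padded_in_features - in_features))
assert weight.shape == (out_features, 1024)

# At inference time, forward() pads the input the same way so the matmul
# dimensions still agree:
x = torch.randn(4, in_features)
x = F.pad(x, pad=(0, padded_in_features - in_features))
assert x.shape == (4, 1024)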
