@@ -39,11 +39,12 @@ def __init__(
39
39
from torchao .utils import find_multiple
40
40
41
41
self .origin_in_features = in_features
42
- in_features = find_multiple (in_features , (1024 ,))
42
+ # pyre-ignore[6]: Incompatible parameter type
43
+ in_features = find_multiple (in_features , 1024 )
43
44
45
+ self .use_bias = bias
44
46
self .in_features = in_features
45
47
self .out_features = out_features
46
- assert not bias , "require bias=False"
47
48
self .device = device
48
49
self .groupsize = groupsize
49
50
self .inner_k_tiles = inner_k_tiles
@@ -80,20 +81,28 @@ def __init__(
80
81
device = device ,
81
82
),
82
83
)
84
+ if bias :
85
+ self .register_buffer (
86
+ "bias" ,
87
+ torch .empty ((out_features ,), dtype = torch .float32 , device = device ),
88
+ )
83
89
84
90
def forward (self , input : torch .Tensor ) -> torch .Tensor :
85
91
if self .padding :
86
92
input = F .pad (input , pad = (0 , self .in_features - self .origin_in_features ))
87
93
# The forward method is replaced. In the original implementation, the forward
88
94
# method is torchao.quantization.GPTQ.linear_forward_int4; here a Vulkan custom
89
95
# operator is called instead.
90
- return torch .ops .et_vk .linear_weight_int4 (
96
+ r = torch .ops .et_vk .linear_weight_int4 (
91
97
input ,
92
98
self .weight ,
93
99
self .groupsize ,
94
100
self .scales_and_zeros ,
95
101
self .inner_k_tiles ,
96
102
)
103
+ if self .use_bias :
104
+ return r + self .bias
105
+ return r
97
106
98
107
99
108
# This function is coped from torchao.quantization.GPTQ._replace_linear_int4
@@ -128,7 +137,7 @@ def _vk_replace_linear_int4(
128
137
new_linear = linear_class (
129
138
child .in_features ,
130
139
child .out_features ,
131
- bias = False ,
140
+ bias = child . bias is not None ,
132
141
device = child .weight .device ,
133
142
groupsize = groupsize ,
134
143
inner_k_tiles = inner_k_tiles ,
@@ -138,6 +147,9 @@ def _vk_replace_linear_int4(
138
147
if copy_weights and child .weight .device != torch .device ("meta" ):
139
148
# pyre-fixme[16]: `Module` has no attribute `weight`.
140
149
new_linear .weight = child .weight
150
+ if child .bias is not None :
151
+ # pyre-fixme[16]: `Module` has no attribute `bias`.
152
+ new_linear .bias = child .bias
141
153
setattr (module , name , new_linear )
142
154
else :
143
155
_vk_replace_linear_int4 (
@@ -189,7 +201,6 @@ def _create_quantized_state_dict(
189
201
mod .out_features < self .feature_limit
190
202
and mod .in_features < self .feature_limit
191
203
):
192
- assert not mod .bias
193
204
out_features = mod .out_features
194
205
in_features = mod .in_features
195
206
logging .info (f"linear: { fqn } , in={ in_features } , out={ out_features } " )
@@ -210,7 +221,8 @@ def _create_quantized_state_dict(
210
221
logging .warn (
211
222
f"warning: { fqn } is padded to satisfy in_features % 1024 == 0"
212
223
)
213
- padded_in_features = find_multiple (in_features , (1024 ,))
224
+ # pyre-ignore[6]: Incompatible parameter type
225
+ padded_in_features = find_multiple (in_features , 1024 )
214
226
weight = F .pad (
215
227
weight , pad = (0 , padded_in_features - in_features )
216
228
)
0 commit comments