Commit 54a66e5

[Misc] Update compressed-tensors WNA16 to support zero-points (#14211)

1 parent 280d62b
File tree: 6 files changed, +85 -45 lines

tests/quantization/test_compressed_tensors.py

Lines changed: 15 additions & 6 deletions
@@ -261,16 +261,23 @@ def check_model(model):
 
 @pytest.mark.parametrize(
     "wNa16_args",
-    [
-        ("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8),
-        ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8),
-        ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4),
-    ],
+    [("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8,
+      True, False),
+     ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8, True,
+      False),
+     ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4,
+      True, False),
+     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-awq-group128-asym256", "group", 128,
+      8, False, False),
+     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-Channel",
+      "channel", None, 8, False, False),
+     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-ActOrder",
+      "group", 128, 8, False, True)],
 )
 @pytest.mark.skipif(not current_platform.is_cuda(),
                     reason="The tests are skipped on non-CUDA platform.")
 def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
-    model, strategy, group, pack_factor = wNa16_args
+    model, strategy, group, pack_factor, symmetric, has_g_idx = wNa16_args
     with vllm_runner(model) as llm:
 
         def check_model(model):
@@ -286,6 +293,8 @@ def check_model(model):
                                                   if group is None else group)
 
             assert qkv_proj.scheme.pack_factor == pack_factor
+            assert qkv_proj.scheme.symmetric == symmetric
+            assert qkv_proj.scheme.has_g_idx == has_g_idx
 
         llm.apply_model(check_model)
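Note (not part of the diff): the three added checkpoints exercise asymmetric (zero-point) quantization, and the last two fields of each tuple are the expected scheme.symmetric and scheme.has_g_idx values. As a quick sanity check on the expected pack_factor values in these tuples (8 for the 4-bit models, 4 for the 8-bit one), the scheme computes the factor as 32 // num_bits; a minimal Python sketch:

# Minimal sketch: pack_factor is the number of quantized values packed into
# one 32-bit word, matching the expected values in the parametrize tuples.
for num_bits, expected_pack_factor in ((4, 8), (8, 4)):
    assert 32 // num_bits == expected_pack_factor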

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 3 additions & 3 deletions
@@ -302,14 +302,12 @@ def _is_fp8_w8a16(self, weight_quant: BaseModel,
     def _is_wNa16_group_channel(self, weight_quant: BaseModel,
                                 input_quant: BaseModel) -> bool:
         input_quant_none = input_quant is None
-        is_symmetric = weight_quant.symmetric
         is_channel_group = (
             weight_quant.strategy == QuantizationStrategy.CHANNEL.value
             or weight_quant.strategy == QuantizationStrategy.GROUP.value)
         is_static = not weight_quant.dynamic
 
-        return (is_channel_group and input_quant_none and is_symmetric
-                and is_static)
+        return (is_channel_group and input_quant_none and is_static)
 
     def _get_scheme_from_parts(
             self, weight_quant: BaseModel,
@@ -319,6 +317,7 @@ def _get_scheme_from_parts(
         if self._is_wNa16_group_channel(weight_quant, input_quant):
             if (self.quant_format == CompressionFormat.marlin_24.value
                     and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
+                assert weight_quant.symmetric
                 return CompressedTensorsW4A16Sparse24(
                     strategy=weight_quant.strategy,
                     num_bits=weight_quant.num_bits,
@@ -328,6 +327,7 @@ def _get_scheme_from_parts(
             return CompressedTensorsWNA16(
                 num_bits=weight_quant.num_bits,
                 strategy=weight_quant.strategy,
+                symmetric=weight_quant.symmetric,
                 group_size=weight_quant.group_size,
                 actorder=weight_quant.actorder)
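For context, the effect of this change is that _is_wNa16_group_channel no longer filters out asymmetric configs, so they reach CompressedTensorsWNA16 with symmetric=weight_quant.symmetric; only the marlin_24 path still asserts symmetry. A hedged, simplified sketch of the updated predicate (standalone names, not the actual vLLM class):

# Simplified standalone sketch: symmetry is intentionally absent from the
# WNA16 group/channel dispatch decision after this change.
def is_wna16_group_channel(strategy: str, input_quant, dynamic: bool) -> bool:
    return (strategy in ("channel", "group") and input_quant is None
            and not dynamic)

assert is_wna16_group_channel("group", None, False)  # asymmetric configs now pass too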

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py

Lines changed: 39 additions & 3 deletions
@@ -12,11 +12,15 @@
     MPLinearLayerConfig, choose_mp_linear_kernel)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     marlin_repeat_scales_on_all_ranks)
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
+                                           PackedColumnParameter,
                                            PackedvLLMParameter,
                                            RowvLLMParameter)
+# yapf: enable
 from vllm.scalar_type import scalar_types
 
 logger = init_logger(__name__)
@@ -26,6 +30,7 @@
     4: scalar_types.uint4b8,
     8: scalar_types.uint8b128
 }
+WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}
 WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
 
 
@@ -36,10 +41,12 @@ def __init__(self,
                  strategy: str,
                  num_bits: int,
                  group_size: Optional[int] = None,
+                 symmetric: Optional[bool] = True,
                  actorder: Optional[ActivationOrdering] = None):
 
         self.pack_factor = 32 // num_bits
         self.strategy = strategy
+        self.symmetric = symmetric
         self.group_size = -1 if group_size is None else group_size
         self.has_g_idx = actorder == ActivationOrdering.GROUP
 
@@ -53,7 +60,9 @@ def __init__(self,
                 f"Unsupported num_bits = {num_bits}. "
                 f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")
 
-        self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]
+        self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits]
+                           if not self.symmetric else
+                           WNA16_SUPPORTED_TYPES_MAP[num_bits])
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -75,7 +84,7 @@ def create_weights(self, layer: torch.nn.Module, output_size: int,
             weight_type=self.quant_type,
             act_type=params_dtype,
             group_size=self.group_size,
-            zero_points=False,
+            zero_points=not self.symmetric,
             has_g_idx=self.has_g_idx
         )
 
@@ -120,13 +129,37 @@ def create_weights(self, layer: torch.nn.Module, output_size: int,
                 dtype=params_dtype,
             )
         }
+
+        zeros_args = {
+            "weight_loader":
+            weight_loader,
+            "data":
+            torch.zeros(
+                output_size_per_partition // self.pack_factor,
+                scales_and_zp_size,
+                dtype=torch.int32,
+            )
+        }
+
         if not partition_scales:
             weight_scale = ChannelQuantScaleParameter(output_dim=0,
                                                       **weight_scale_args)
+
+            if not self.symmetric:
+                qzeros = PackedColumnParameter(output_dim=0,
+                                               packed_dim=0,
+                                               packed_factor=self.pack_factor,
+                                               **zeros_args)
         else:
             weight_scale = GroupQuantScaleParameter(output_dim=0,
                                                     input_dim=1,
                                                     **weight_scale_args)
+            if not self.symmetric:
+                qzeros = PackedvLLMParameter(input_dim=1,
+                                             output_dim=0,
+                                             packed_dim=0,
+                                             packed_factor=self.pack_factor,
+                                             **zeros_args)
 
         # A 2D array defining the original shape of the weights
         # before packing
@@ -138,6 +171,9 @@ def create_weights(self, layer: torch.nn.Module, output_size: int,
         layer.register_parameter("weight_scale", weight_scale)
         layer.register_parameter("weight_shape", weight_shape)
 
+        if not self.symmetric:
+            layer.register_parameter("weight_zero_point", qzeros)
+
        # group index (for activation reordering)
        if self.has_g_idx:
            weight_g_idx = RowvLLMParameter(data=torch.empty(
@@ -151,7 +187,7 @@ def create_weights(self, layer: torch.nn.Module, output_size: int,
         self.kernel = kernel_type(mp_linear_kernel_config,
                                   w_q_param_name="weight_packed",
                                   w_s_param_name="weight_scale",
-                                  w_zp_param_name=None,
+                                  w_zp_param_name="weight_zero_point",
                                   w_gidx_param_name="weight_g_idx")
 
        # Checkpoints are serialized in compressed-tensors format, which is
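
To summarize the scheme change: asymmetric checkpoints now map to unsigned scalar types with explicit zero-points (uint4/uint8), symmetric ones keep the bias-shifted types (uint4b8/uint8b128), and zero_points=not self.symmetric is passed through the kernel config. A minimal sketch of that selection, using the two maps defined in this file:

from vllm.scalar_type import scalar_types

# The two maps defined in compressed_tensors_wNa16.py.
WNA16_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4b8, 8: scalar_types.uint8b128}
WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}

def pick_quant_type(num_bits: int, symmetric: bool):
    # Asymmetric weights carry an explicit packed zero-point tensor, so they
    # use the plain unsigned types; symmetric weights keep the biased types.
    return (WNA16_SUPPORTED_TYPES_MAP[num_bits]
            if symmetric else WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits])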

vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py

Lines changed: 2 additions & 5 deletions
@@ -26,17 +26,14 @@ def get_min_capability(cls) -> int:
     @classmethod
     def can_implement(cls,
                       c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
+
         if c.has_g_idx and\
             c.partition_weight_shape[0] != c.full_weight_shape[0]:
             return False, "Act reordering currently not supported by Machete, "\
                           "when the input features are partitioned across "\
                           "devices"
-
         if c.zero_points:
-            return False, "Zero points currently not supported by "\
-                          " Compressed Tensors + Machete. (Kernel supports it"\
-                          " but CompressedTensorsWNA16 does not so support has"\
-                          " not been added to MacheteWNA16Kernel yet"
+            return False, "Zero points currently not supported by Machete"
 
         if c.weight_type not in query_machete_supported_quant_types(
                 c.zero_points):
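
This hunk only shortens the rejection message: the Machete kernel still declines zero-point configs. A hedged, simplified sketch of the gate it keeps (standalone function, not the actual kernel class):

from typing import Optional, Tuple

def machete_zero_point_gate(zero_points: bool) -> Tuple[bool, Optional[str]]:
    # Simplified stand-in for the check can_implement keeps in place.
    if zero_points:
        return False, "Zero points currently not supported by Machete"
    return True, None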

vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py

Lines changed: 24 additions & 27 deletions
@@ -9,7 +9,7 @@
     MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
     check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
     marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
-    query_marlin_supported_quant_types)
+    marlin_zero_points, query_marlin_supported_quant_types, unpack_cols)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            permute_param_layout_)
 
@@ -25,10 +25,6 @@ def get_min_capability(cls) -> int:
     @classmethod
     def can_implement(cls,
                       c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
-        if c.zero_points:
-            return False, "Zero points currently not supported by "\
-                          " MarlinLinearKernel. Will be added when AWQMarlin "\
-                          "is migrated over to using MPLinearKernel backend"
 
         quant_types = query_marlin_supported_quant_types(c.zero_points)
         if c.weight_type not in quant_types:
@@ -67,28 +63,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if self.w_zp_name is None:
             self.w_zp_name = "w_zp"
 
-        if c.has_g_idx:
-            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
-                getattr(layer, self.w_gidx_name))
-            self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
-            layer.g_idx_sort_indices = g_idx_sort_indices
-        else:
-            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
-            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
-
-        if c.zero_points:
-            pass
-            # TODO (lucas): add the following when AWQMarlin is migrated over to
-            # using MPLinearKernel backend
-            # self._transform_param(layer, self.w_zp_name, lambda x: \
-            #     marlin_zero_points(
-            #         x,
-            #         size_k=c.partition_weight_shape[0],
-            #         size_n=c.partition_weight_shape[1],
-            #         num_bits=c.weight_type.size_bits))
-        else:
-            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
-
         def transform_w_q(x):
             assert isinstance(x, BasevLLMParameter)
             permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
@@ -108,6 +82,28 @@ def transform_w_s(x):
                                       group_size=c.group_size)
             return x
 
+        if c.has_g_idx:
+            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
+                getattr(layer, self.w_gidx_name))
+            self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
+            layer.g_idx_sort_indices = g_idx_sort_indices
+        else:
+            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
+            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
+
+        if c.zero_points:
+            grouped_k = (c.partition_weight_shape[0] //
+                         c.group_size if c.group_size != -1 else 1)
+            self._transform_param(layer, self.w_zp_name, lambda x: \
+                marlin_zero_points(
+                    unpack_cols(x.t(), c.weight_type.size_bits,
+                                grouped_k,
+                                c.partition_weight_shape[1]),
+                    size_k=grouped_k,
+                    size_n=c.partition_weight_shape[1],
+                    num_bits=c.weight_type.size_bits))
+        else:
+            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
         self._transform_param(layer, self.w_q_name, transform_w_q)
         self._transform_param(layer, self.w_s_name, transform_w_s)
 
@@ -131,5 +127,6 @@ def apply_weights(self,
             wtype=c.weight_type,
             input_size_per_partition=c.partition_weight_shape[0],
             output_size_per_partition=c.partition_weight_shape[1],
+            has_zp=self.config.zero_points,
             is_k_full=self.is_k_full,
             bias=bias)
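
The shape math behind the new zero-point path: the checkpoint stores zero-points packed along the output dimension, so they are first unpacked to one row per quantization group (grouped_k rows by size_n output features, via unpack_cols) and then repacked into Marlin's layout with marlin_zero_points. A small sketch of the grouped_k computation, with hypothetical example sizes:

def zp_unpacked_shape(partition_k: int, size_n: int, group_size: int):
    # Mirrors the grouped_k expression added above: one row of zero-points per
    # quantization group, or a single row for channelwise (group_size == -1).
    grouped_k = partition_k // group_size if group_size != -1 else 1
    return grouped_k, size_n

# Hypothetical layer sizes, purely illustrative.
assert zp_unpacked_shape(4096, 11008, 128) == (32, 11008)
assert zp_unpacked_shape(4096, 11008, -1) == (1, 11008)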

vllm/model_executor/layers/quantization/utils/marlin_utils.py

Lines changed: 2 additions & 1 deletion
@@ -332,6 +332,7 @@ def apply_gptq_marlin_linear(
         wtype: ScalarType,
         output_size_per_partition: int,
         input_size_per_partition: int,
+        has_zp: bool,
         is_k_full: bool,
         bias: Optional[torch.Tensor] = None,
         use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
@@ -356,8 +357,8 @@ def apply_gptq_marlin_linear(
                                  size_n=output_size_per_partition,
                                  size_k=input_size_per_partition,
                                  is_k_full=is_k_full,
-                                 has_zp=False,
                                  use_atomic_add=use_atomic_add,
+                                 has_zp=has_zp,
                                  use_fp32_reduce=use_fp32_reduce,
                                  is_zp_float=False)
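This last hunk just threads the flag through: apply_gptq_marlin_linear now takes has_zp from its caller (the Marlin kernel passes has_zp=self.config.zero_points) instead of hard-coding has_zp=False at the GEMM call. A self-contained sketch of the plumbing, with hypothetical stand-in names:

def marlin_gemm_stub(*, has_zp: bool, is_zp_float: bool = False) -> str:
    # Stand-in for the underlying Marlin GEMM call; only the flags are shown.
    return f"has_zp={has_zp}, is_zp_float={is_zp_float}"

def apply_linear_stub(has_zp: bool) -> str:
    # Before this commit the wrapper passed has_zp=False unconditionally.
    return marlin_gemm_stub(has_zp=has_zp)

print(apply_linear_stub(has_zp=True))  # -> has_zp=True, is_zp_float=False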
