Skip to content

Commit ec56da8

Browse files
authored
Check group size is divisible by 32
Differential Revision: D66131456 Pull Request resolved: #6941
1 parent 3be3b92 commit ec56da8

File tree

2 files changed

+32
-7
lines changed

2 files changed

+32
-7
lines changed

backends/xnnpack/operators/node_visitor.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -309,27 +309,30 @@ def _check_per_channel_group_params(
309309
num_groups = cast(torch.Tensor, quant_params.scale).shape[1]
310310
assert (
311311
quant_params.axis == 0
312-
), "For per_channel_group quant, axis must be 0, but got {axis}"
312+
), f"For per_channel_group quant, axis must be 0, but got {quant_params.axis}"
313313
assert (
314314
len(dims) == 2
315-
), "For per_channel_group quant, expecting linear weights to be 2d, but got {len(dims)}"
315+
), f"For per_channel_group quant, expecting linear weights to be 2d, but got {len(dims)}"
316316
assert (
317317
num_groups > 0 and quant_params.group_size > 0
318-
), "For per_channel_group quant, num_groups and group_size must be > 0, but got num_groups: {num_groups}, group_size: {quant_params.group_size}"
318+
), f"For per_channel_group quant, num_groups and group_size must be > 0, but got num_groups: {num_groups}, group_size: {quant_params.group_size}"
319319
output_channels = dims[quant_params.axis]
320320
input_channels = dims[quant_params.axis ^ 1]
321+
assert (
322+
quant_params.group_size % 32 == 0
323+
), f"Delegation to XNNPACK requires group_size to be a multiple of 32, but got {quant_params.group_size}"
321324
assert (
322325
output_channels == cast(torch.Tensor, quant_params.scale).shape[0]
323-
), "For per_channel_group quant, expecting output channels to match scale.shape[0], gut got: {output_channels}, scale.shape[0]: {quant_params.scale.shape[0]}"
326+
), f"For per_channel_group quant, expecting output channels to match scale.shape[0], but got: {output_channels}, scale.shape[0]: {quant_params.scale.shape[0]}"
324327
assert (
325328
input_channels % num_groups == 0
326-
), "For per_channel_group quant, expecting input channels to be divisible by num_groups, but got ic: {input_channels}, num_groups: {num_groups}"
329+
), f"For per_channel_group quant, expecting input channels to be divisible by num_groups, but got ic: {input_channels}, num_groups: {num_groups}"
327330
assert (
328331
input_channels % quant_params.group_size == 0
329-
), "For per_channel_group quant, expecting input channels to be divisible by group_size, but got ic: {input_channels}, group_size: {quant_params.group_size}"
332+
), f"For per_channel_group quant, expecting input channels to be divisible by group_size, but got ic: {input_channels}, group_size: {quant_params.group_size}"
330333
assert (
331334
input_channels / quant_params.group_size == num_groups
332-
), "For per_channel_group quant, expecting input channels // group_size == num_groups, but got ic: {input_channels}, group_size: {quant_params.group_size}, num_groups: {num_groups}"
335+
), f"For per_channel_group quant, expecting input channels // group_size == num_groups, but got ic: {input_channels}, group_size: {quant_params.group_size}, num_groups: {num_groups}"
333336

334337
# For now group quantization is only supported for 4b weights
335338
assert quant_params.is_qc4w, "Only 4b group quantization is supported"

backends/xnnpack/test/ops/linear.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,28 @@ def test_qd8_fp16_per_token_weight_per_channel_group_int4(self):
458458
lin_mod, inputs, group_size=bl, use_bias=use_bias, atol=1e-2
459459
)
460460

461+
@unittest.skipIf(
462+
not torchao_installed, "Per Channel Group Quantization Requires TorchAO"
463+
)
464+
def test_qd8_fp32_per_token_groupwise_unsupported_groupsize(self):
465+
# groupsize must be multiple of 32
466+
lin_mod = BaseLinear(
467+
in_size=1,
468+
input_channels=60,
469+
output_channels=60,
470+
dtype=torch.float32,
471+
use_bias=True,
472+
)
473+
inputs = lin_mod.get_inputs()
474+
475+
with self.assertRaisesRegex(
476+
AssertionError,
477+
"Delegation to XNNPACK requires group_size to be a multiple of 32, but got 30",
478+
):
479+
self._test_groupwise_dq_linear(
480+
lin_mod, inputs, group_size=30, use_bias=False, atol=1e-2
481+
)
482+
461483
def _test_linear(
462484
self,
463485
make_module,

0 commit comments

Comments
 (0)