
Commit 0a6b599

[XNNPACK] resolve ambiguity around 2d affine quantized tensors

1 parent e37129d commit 0a6b599
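For context, here is a minimal sketch (not part of the commit; shapes are illustrative) of the ambiguity this change resolves: for a 2d affine quantized tensor, a per-token scheme and a per-channel-group scheme with group_size equal to the row width produce the same number of scales, so the shape checks in quant_utils.py cannot tell them apart on their own.

    # Illustrative shapes only: M rows ("tokens"), K columns.
    M, K = 8, 64

    # Per-token: block_size [1, K] covers each full row, giving one
    # scale per row -> M scales in total.
    per_token_scale_numel = M

    # Per-channel-group: scales per row = K // group_size. With
    # group_size == K that is one scale per row -> also M scales.
    group_size = K
    per_channel_group_scale_numel = M * (K // group_size)

    # Identical scale counts: shape arithmetic alone is ambiguous for
    # 2d tensors, which is what the provenance check in this commit
    # disambiguates.
    assert per_token_scale_numel == per_channel_group_scale_numel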

File tree

2 files changed: +30 -25 lines changed

backends/xnnpack/test/ops/test_linear.py

Lines changed: 25 additions & 24 deletions
@@ -645,31 +645,32 @@ def _test_qd8_per_token_weight_per_channel_group_int4(
         bl_sizes = [32, 32, 32, 64]
         N_sizes = [2, 17, 92, 128]
 
-        for use_bias in [True, False]:
-            for M, K, bl, N in zip(M_sizes, K_sizes, bl_sizes, N_sizes):
-                lin_mod = BaseLinear(
-                    in_size=M,
-                    input_channels=K,
-                    output_channels=N,
-                    dtype=dtype,
-                    use_bias=use_bias,
-                )
+        for input_rank in range(2, 4):
+            for use_bias in [True, False]:
+                for M, K, bl, N in zip(M_sizes, K_sizes, bl_sizes, N_sizes):
+                    lin_mod = BaseLinear(
+                        in_size=M,
+                        input_channels=K,
+                        output_channels=N,
+                        dtype=dtype,
+                        use_bias=use_bias,
+                    )
 
-                inputs = lin_mod.get_inputs()
-                # Half requires slightly higher atol, but if you look at error it is not that bad:
-                # Difference: max: 0.00140380859375, abs: 0.00140380859375, mean abs error: 0.00042724609375.
-                # -- Model vs. Reference --
-                # Numel: 4, 4
-                # Median: -0.05023193359375, -0.0516357421875
-                # Mean: 0.2373046875, 0.237060546875
-                # Max: 1.0078125, 1.0078125
-                # Min: -0.08465576171875, -0.08441162109375
-                atol = (
-                    1e-2 if dtype == torch.half else 5e-3
-                )  # TODO(T212995726): Investigate right atol for rand[n] inputs
-                self._test_groupwise_dq_linear(
-                    lin_mod, inputs, group_size=bl, use_bias=use_bias, atol=atol
-                )
+                    inputs = lin_mod.get_inputs(rank=input_rank)
+                    # Half requires slightly higher atol, but if you look at error it is not that bad:
+                    # Difference: max: 0.00140380859375, abs: 0.00140380859375, mean abs error: 0.00042724609375.
+                    # -- Model vs. Reference --
+                    # Numel: 4, 4
+                    # Median: -0.05023193359375, -0.0516357421875
+                    # Mean: 0.2373046875, 0.237060546875
+                    # Max: 1.0078125, 1.0078125
+                    # Min: -0.08465576171875, -0.08441162109375
+                    atol = (
+                        1e-2 if dtype == torch.half else 5e-3
+                    )  # TODO(T212995726): Investigate right atol for rand[n] inputs
+                    self._test_groupwise_dq_linear(
+                        lin_mod, inputs, group_size=bl, use_bias=use_bias, atol=atol
+                    )
 
     def test_fp16_linear(self):
         for use_bias in (True, False):
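The test now sweeps input_rank over 2 and 3 (range(2, 4)), so the groupwise dynamically quantized linear is exercised with both 2d and 3d activations. BaseLinear.get_inputs lives in the ExecuTorch test suite; a hypothetical stand-in for its rank handling might look like the sketch below (the name make_linear_inputs and the leading-singleton-batch convention are assumptions, not the real helper).

    import torch

    def make_linear_inputs(in_size: int, input_channels: int, rank: int,
                           dtype: torch.dtype = torch.float32) -> tuple:
        # Start from the 2d activation shape (M, K) and pad leading
        # singleton batch dims until the requested rank is reached.
        shape = [in_size, input_channels]
        while len(shape) < rank:
            shape.insert(0, 1)
        return (torch.randn(shape, dtype=dtype),)

    # rank=2 -> shape (4, 16); rank=3 -> shape (1, 4, 16)
    inputs = make_linear_inputs(in_size=4, input_channels=16, rank=3)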

backends/xnnpack/utils/quant_utils.py

Lines changed: 5 additions & 1 deletion
@@ -50,7 +50,7 @@ def is_dynamic_qdq(node: torch.fx.Node) -> bool:
     if node.op != "call_function":
         return False
     node_name = format_target_name(node.target.__name__)  # pyre-ignore
-    is_dynamic_affine = is_per_token(node) and not is_per_channel_group(node)
+    is_dynamic_affine = is_per_token(node)
 
     return node_name in _DYNAMIC_OPS or is_dynamic_affine
 
@@ -129,6 +129,9 @@ def is_per_token(node: torch.fx.Node):
 
         flag &= block_size[-1] == input_val.shape[-1]
         flag &= scale_val.numel() == scale_numel_expected
+        scale_node = node.all_input_nodes[1]
+        # per token must have dynamically chosen scale
+        flag &= scale_node.target == operator.getitem
         return flag
 
     return False
@@ -149,6 +152,7 @@ def is_per_channel_group(node: torch.fx.Node):
         scale_numel = list(accumulate(scale_val.shape, operator.mul))[-1]
         input_numel = list(accumulate(input_val.shape, operator.mul))[-1]
         flag &= input_numel == group_size * scale_numel
+        flag &= not is_per_token(node)
         return flag
 
     return False
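Taken together, the quant_utils.py changes break the tie by provenance rather than by shape: a per-token (dynamic) scale is computed at runtime by a choose_qparams-style op and unpacked from its tuple output via operator.getitem, whereas per-channel-group (static) scales are baked-in constants, and is_per_channel_group now explicitly excludes anything that qualifies as per-token. A minimal standalone sketch of the provenance idea follows, with the scale-at-input-index-1 convention taken from the diff.

    import operator
    import torch

    def scale_is_dynamic(node: torch.fx.Node) -> bool:
        # Per the diff, input index 1 of the q/dq node is its scale.
        scale_node = node.all_input_nodes[1]
        # Dynamic scales are tuple elements pulled out of a runtime
        # choose_qparams result, so their producer is operator.getitem;
        # static scales come from graph constants instead.
        return scale_node.target is operator.getitem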
