Skip to content

Commit 0ccf509

Browse files
authored
[XNNPACK] resolve ambiguity around 2d affine quantized tensors
Differential Revision: D70719546 Pull Request resolved: #8958
1 parent 157bf91 commit 0ccf509

File tree

4 files changed

+55
-32
lines changed

4 files changed

+55
-32
lines changed

backends/xnnpack/operators/op_dynamic_dequantize_ops.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)
1414
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import XNNGraph
1515
from executorch.backends.xnnpack.utils.quant_utils import (
16+
is_dynamic_qdq,
1617
is_per_channel_group,
1718
is_per_token,
1819
)
@@ -92,7 +93,8 @@ def define_node(
9293
"""
9394
We always define dequantize affine nodes because they are always explicit
9495
"""
95-
if is_per_channel_group(node):
96+
is_dynamic = is_dynamic_qdq(node)
97+
if is_per_channel_group(node) and not is_dynamic:
9698
check_or_raise(
9799
is_param_node(self._exported_program, node.all_input_nodes[0]),
98100
f"Expected quantize affine node with per-token semantics to be used "
@@ -103,7 +105,7 @@ def define_node(
103105
return
104106

105107
check_or_raise(
106-
is_per_token(node),
108+
is_per_token(node) and is_dynamic,
107109
"Expecting Affine Dequantized Op to have per-token semantics",
108110
)
109111
# This must be a per-token affine dequantized node, so let us serialize as such

backends/xnnpack/operators/op_dynamic_quantize_ops.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
XNode,
1919
)
2020
from executorch.backends.xnnpack.utils.quant_utils import (
21+
is_dynamic_qdq,
2122
is_per_channel_group,
2223
is_per_token,
2324
)
@@ -138,13 +139,14 @@ def define_node(
138139
"""
139140
We always define quantize affine nodes because they are always explicit
140141
"""
141-
if is_per_channel_group(node):
142+
is_dynamic = is_dynamic_qdq(node)
143+
if is_per_channel_group(node) and not is_dynamic:
142144
# Affine quantized was recognized as per channel group which means that it should
143145
# be skipped as this means it is used in front of a weight node
144146
return
145147

146148
check_or_raise(
147-
is_per_token(node),
149+
is_per_token(node) and is_dynamic,
148150
"Encountered affine quantized op which does not have per-token semantics",
149151
)
150152
# Treat this node as dynamic per-token quantization

backends/xnnpack/test/ops/test_linear.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -645,31 +645,32 @@ def _test_qd8_per_token_weight_per_channel_group_int4(
645645
bl_sizes = [32, 32, 32, 64]
646646
N_sizes = [2, 17, 92, 128]
647647

648-
for use_bias in [True, False]:
649-
for M, K, bl, N in zip(M_sizes, K_sizes, bl_sizes, N_sizes):
650-
lin_mod = BaseLinear(
651-
in_size=M,
652-
input_channels=K,
653-
output_channels=N,
654-
dtype=dtype,
655-
use_bias=use_bias,
656-
)
648+
for input_rank in range(2, 4):
649+
for use_bias in [True, False]:
650+
for M, K, bl, N in zip(M_sizes, K_sizes, bl_sizes, N_sizes):
651+
lin_mod = BaseLinear(
652+
in_size=M,
653+
input_channels=K,
654+
output_channels=N,
655+
dtype=dtype,
656+
use_bias=use_bias,
657+
)
657658

658-
inputs = lin_mod.get_inputs()
659-
# Half requires slightly higher atol, but if you look at error it is not that bad:
660-
# Difference: max: 0.00140380859375, abs: 0.00140380859375, mean abs error: 0.00042724609375.
661-
# -- Model vs. Reference --
662-
# Numel: 4, 4
663-
# Median: -0.05023193359375, -0.0516357421875
664-
# Mean: 0.2373046875, 0.237060546875
665-
# Max: 1.0078125, 1.0078125
666-
# Min: -0.08465576171875, -0.08441162109375
667-
atol = (
668-
1e-2 if dtype == torch.half else 5e-3
669-
) # TODO(T212995726): Investigate right atol for rand[n] inputs
670-
self._test_groupwise_dq_linear(
671-
lin_mod, inputs, group_size=bl, use_bias=use_bias, atol=atol
672-
)
659+
inputs = lin_mod.get_inputs(rank=input_rank)
660+
# Half requires slightly higher atol, but if you look at error it is not that bad:
661+
# Difference: max: 0.00140380859375, abs: 0.00140380859375, mean abs error: 0.00042724609375.
662+
# -- Model vs. Reference --
663+
# Numel: 4, 4
664+
# Median: -0.05023193359375, -0.0516357421875
665+
# Mean: 0.2373046875, 0.237060546875
666+
# Max: 1.0078125, 1.0078125
667+
# Min: -0.08465576171875, -0.08441162109375
668+
atol = (
669+
1e-2 if dtype == torch.half else 5e-3
670+
) # TODO(T212995726): Investigate right atol for rand[n] inputs
671+
self._test_groupwise_dq_linear(
672+
lin_mod, inputs, group_size=bl, use_bias=use_bias, atol=atol
673+
)
673674

674675
def test_fp16_linear(self):
675676
for use_bias in (True, False):

backends/xnnpack/utils/quant_utils.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,30 @@
4747

4848

4949
def is_dynamic_qdq(node: torch.fx.Node) -> bool:
50-
if node.op != "call_function":
50+
# check has dynamic qdq name
51+
if not (is_quant(node) or is_dequant(node)):
52+
return False
53+
54+
# check scales and zp are dynamically chosen
55+
node_input_args = node.args
56+
if is_affine_qdq(node):
57+
node_input_args = extract_qdq_affine_op_args_for_decomposed_ops(node)
58+
59+
scale = node_input_args[1]
60+
zp = node_input_args[2]
61+
if not (isinstance(scale, torch.fx.Node) and isinstance(zp, torch.fx.Node)):
62+
return False
63+
64+
if not (scale.target == operator.getitem and zp.target == operator.getitem):
65+
return False
66+
67+
scale_choose_qparam = scale.all_input_nodes[0]
68+
zp_choose_qparam = zp.all_input_nodes[0]
69+
70+
if not (is_qparam(scale_choose_qparam) and is_qparam(zp_choose_qparam)):
5171
return False
52-
node_name = format_target_name(node.target.__name__) # pyre-ignore
53-
is_dynamic_affine = is_per_token(node) and not is_per_channel_group(node)
5472

55-
return node_name in _DYNAMIC_OPS or is_dynamic_affine
73+
return True
5674

5775

5876
def is_qparam(node: torch.fx.Node) -> bool:

0 commit comments

Comments (0)