
Commit 884d16d

Allow partitioning quantized linear for FP32-only partition
Differential Revision: D67011716
Pull Request resolved: #7284
1 parent 18142f7 commit 884d16d
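
For context, this change lets a graph that still contains quantize/dequantize nodes be lowered with only the FP32 GEMM configs enabled: a quantized linear whose detected precision is STATIC_QUANT or DYNAMIC_QUANT can now be claimed as an fp32 GEMM instead of being skipped. A minimal sketch of that flow follows; the PT2E quantization entry points (XNNPACKQuantizer, prepare_pt2e/convert_pt2e, export_for_training) are the usual ExecuTorch flow assumed for illustration, not code from this commit.

# Sketch only: quantize a Linear with PT2E, then lower with the FP32-only
# XNNPACK partitioner, which this commit allows to claim the quantized linear.
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackFloatingPointPartitioner,
)
from executorch.exir import to_edge_transform_and_lower
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

model = torch.nn.Linear(37, 17).eval()
inputs = (torch.randn(4, 37),)

# Standard PT2E quantization (assumed flow, unchanged by this commit).
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
gm = torch.export.export_for_training(model, inputs).module()
prepared = prepare_pt2e(gm, quantizer)
prepared(*inputs)  # calibrate
quantized = convert_pt2e(prepared)

# Lower with only FP32 enabled: the quantized linear is partitioned as fp32,
# while the surrounding Q/DQ nodes stay in the graph (see the tests below).
program = to_edge_transform_and_lower(
    torch.export.export(quantized, inputs),
    partitioner=[XnnpackFloatingPointPartitioner()],
).to_executorch()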

2 files changed: +178, -6 lines


backends/xnnpack/partition/config/gemm_configs.py

Lines changed: 26 additions & 3 deletions
@@ -93,6 +93,22 @@ def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType:

         return ConfigPrecisionType.STATIC_QUANT

+    def _overwrite_precision(self, node: torch.fx.Node):
+        precision = self._detect_precision(node)
+        if precision not in self.enabled_precision_types:
+            # detected precision is not enabled, let's try to partition it as fp32
+            if self.enabled_precision_types == [ConfigPrecisionType.FP32]:
+                # if only fp32 is enabled, then we can still partition fp32 gemms
+                # even within a quantized graph
+                if precision in [
+                    ConfigPrecisionType.STATIC_QUANT,
+                    ConfigPrecisionType.DYNAMIC_QUANT,
+                ]:
+                    precision = ConfigPrecisionType.FP32
+                    logging.info(f"Overwriting precision, partitioning {node} as FP32")
+                    return True, precision
+        return False, precision
+
     def get_deps(
         self,
         node: torch.fx.Node,
@@ -107,7 +123,7 @@ def get_deps(
         if precision not in self.supported_precision_types():
             # detected precision but it is either disabled or not supported
             return (False, [])
-
+        _, precision = self._overwrite_precision(node)
         valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
         valid_weight, weight_deps = self._get_weight_deps(node, ep, precision)
         valid_act, act_deps = self._get_act_deps(node, ep, precision)
@@ -193,7 +209,7 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
-        if len(node.all_input_nodes) > 2 and self.bias_idx:
+        if len(node.all_input_nodes) > 2 and self.bias_idx is not None:
             bias_node = get_input_node(node, self.bias_idx)
             if bias_node:
                 if not is_param_node(ep, bias_node):
@@ -266,7 +282,14 @@ def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+            # if force fp32_dynamic_linear is enabled, then we
+            # do not partition the weight node
+            return (True, [])
+
+        # Since we are in Linear, we may assume that the weights are indeed static.
+        overwritten_linear_precision, new_precision = self._overwrite_precision(node)
+        if new_precision == ConfigPrecisionType.FP32 and overwritten_linear_precision:
+            # if overwriting quantized precision to fp32, then we
             # do not partition the weight node
             return (True, [])
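
To make the new control flow easier to follow, here is a standalone restatement of the decision that _overwrite_precision implements. This is an illustrative sketch with a stand-in enum and a hypothetical free function; the real method lives on the GEMM config classes and takes a torch.fx.Node.

# Illustrative sketch of the _overwrite_precision decision, detached from the
# config classes so it can run on its own (not code from this commit).
from enum import Enum, auto

class ConfigPrecisionType(Enum):  # stand-in for the real enum
    FP32 = auto()
    STATIC_QUANT = auto()
    DYNAMIC_QUANT = auto()

def overwrite_precision(detected, enabled_precision_types):
    # Only when FP32 is the sole enabled precision do we fall back to
    # partitioning a quantized gemm as an fp32 gemm.
    if detected not in enabled_precision_types:
        if enabled_precision_types == [ConfigPrecisionType.FP32] and detected in (
            ConfigPrecisionType.STATIC_QUANT,
            ConfigPrecisionType.DYNAMIC_QUANT,
        ):
            return True, ConfigPrecisionType.FP32
    return False, detected

# Quantized linear, FP32-only partitioner: precision is overwritten to FP32.
assert overwrite_precision(
    ConfigPrecisionType.STATIC_QUANT, [ConfigPrecisionType.FP32]
) == (True, ConfigPrecisionType.FP32)
# Quantized partitioning enabled: nothing is overwritten.
assert overwrite_precision(
    ConfigPrecisionType.STATIC_QUANT,
    [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT],
) == (False, ConfigPrecisionType.STATIC_QUANT)

In _get_weight_deps above, the returned (overwritten, precision) pair is what lets the weight node stay out of the partition when this fallback triggers.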

backends/xnnpack/test/ops/test_linear.py

Lines changed: 152 additions & 3 deletions
@@ -8,14 +8,16 @@
 import unittest

 from itertools import product
-from typing import Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple

 import torch
 from executorch.backends.xnnpack.partition.config.xnnpack_config import (
     ConfigPrecisionType,
 )
-from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
-
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
+    XnnpackFloatingPointPartitioner,
+    XnnpackPartitioner,
+)
 from executorch.backends.xnnpack.test.tester import Quantize, Tester
 from executorch.backends.xnnpack.test.tester.tester import (
     Partition,
@@ -672,3 +674,150 @@ def _test_groupwise_dq_linear(
             .serialize()
             .run_method_and_compare_outputs(atol=atol, rtol=rtol)
         )
+
+    def _test_linear_overwrite_precision(
+        self,
+        make_module: Callable[[int, int], torch.nn.Module],
+        uses_bias: bool,
+        quant_type: str,
+        quant_node_checks: List[Dict[str, int]],
+        atol: float = 1e-03,
+    ):
+        """
+        Tests the precision-overwrite path for the linear op: we partition, lower,
+        and run a quantized linear model as an fp32 linear op.
+        When using legacy_mode, we verify that [add]mm is NOT partitioned, because:
+        (1) We can't assume that weights are always static (non-param).
+        (2) Alternatively, when lowering [add]mm to xnn::bmm we can't support bias.
+        (2)(a) Lowering only non-bias [add]mm, exposed only on the legacy path, was deemed low ROI.
+        """
+
+        in_sizes = [3, 4, 4]
+        input_sizes = [4, 37, 17]
+        output_sizes = [4, 17, 37]
+
+        assert quant_type in ["per_tensor", "per_channel", "per_channel_dynamic"]
+        per_channel = "per_channel" in quant_type
+        dynamic = "dynamic" in quant_type
+        quant_config = get_symmetric_quantization_config(
+            is_per_channel=per_channel,
+            is_dynamic=dynamic,
+        )
+        # Using FP32 partitioner for this quantized graph
+        partitioner = XnnpackFloatingPointPartitioner()
+
+        def get_qnode_checks(quant_node_checks, dialect):
+            d = {}
+            assert dialect in ["aten", "edge"]
+            if dialect == "aten":
+                d = {
+                    f"torch.ops.quantized_decomposed.{op}": count
+                    for op, count in quant_node_checks.items()
+                }
+            elif dialect == "edge":
+                d = {
+                    f"executorch.exir.dialects.edge._ops.quantized_decomposed.{op}".replace(
+                        ".", "_"
+                    ): count
+                    for op, count in quant_node_checks.items()
+                }
+            assert len(d) == len(quant_node_checks)
+            return d
+
+        for i, _ in enumerate(in_sizes):
+            torch._dynamo.reset()
+            in_size = int(in_sizes[i])
+            input_size = int(input_sizes[i])
+            output_size = int(output_sizes[i])
+            input_shape = [in_size] + [input_size]
+            module = make_module(input_size, output_size).eval()
+            inputs = (torch.randn(input_shape),)
+
+            addmm_op_str = (
+                "executorch_exir_dialects_edge__ops_aten_addmm_default"
+                if uses_bias
+                else "executorch_exir_dialects_edge__ops_aten_mm_default"
+            )
+            linear_op_str = "executorch_exir_dialects_edge__ops_aten_linear_default"
+
+            for legacy_mode in (True, False):
+                tester = (
+                    Tester(module, inputs)
+                    .quantize(Quantize(quantization_config=quant_config))
+                    .export()
+                    .dump_artifact()
+                    .check_count(get_qnode_checks(quant_node_checks, "aten"))
+                )
+
+                if legacy_mode:
+                    tester.to_edge()
+                    tester.partition(Partition(partitioner=partitioner))
+                    # We don't expect [add]mm to be partitioned
+                    tester.check([addmm_op_str])
+                else:
+                    tester.to_edge_transform_and_lower(
+                        ToEdgeTransformAndLower(partitioners=[partitioner])
+                    )
+                    # We do expect linear to be partitioned
+                    tester.check_not([linear_op_str])
+
+                # For legacy mode, fp32 permute_copy gets partitioned (just a side effect).
+                # For new mode, fp32 linear gets partitioned.
+                tester.check_count(
+                    {"torch.ops.higher_order.executorch_call_delegate": 1}
+                )
+
+                # Typically, we would not see any quantized ops in the graph.
+                # But here we shouldn't partition these.
+                tester.check_count(get_qnode_checks(quant_node_checks, "edge"))
+
+                # TODO: Need to figure out how to load quantized ops in pybindings.
+                # tester.to_executorch()
+                # tester.serialize()
+                # tester.run_method_and_compare_outputs(
+                #     qtol=bool(quant_config), atol=atol
+                # )
+
+    def test_qs8_as_fp32(self):
+        for use_bias in (True, False):
+            self._test_linear_overwrite_precision(
+                lambda in_size, out_size: torch.nn.Linear(
+                    in_size, out_size, bias=use_bias  # noqa
+                ),
+                use_bias,
+                "per_tensor",
+                quant_node_checks={
+                    "quantize_per_tensor.default": 2,  # 1: act, 1: output
+                    "dequantize_per_tensor.default": 3,  # 1: act, 1: weight, 1: output
+                },
+            )
+
+    def test_qc8_as_fp32(self):
+        for use_bias in (True, False):
+            self._test_linear_overwrite_precision(
+                lambda in_size, out_size: torch.nn.Linear(
+                    in_size, out_size, bias=use_bias  # noqa
+                ),
+                use_bias,
+                "per_channel",
+                quant_node_checks={
+                    "quantize_per_tensor.default": 2,  # 1: act, 1: output
+                    "dequantize_per_tensor.default": 2,  # 1: act, 1: output
+                    "dequantize_per_channel.default": 1,  # 1: weight
+                },
+            )
+
+    def test_qd8_as_fp32(self):
+        for use_bias in (True, False):
+            self._test_linear_overwrite_precision(
+                lambda in_size, out_size: torch.nn.Linear(
+                    in_size, out_size, bias=use_bias  # noqa
+                ),
+                use_bias,
+                "per_channel_dynamic",
+                quant_node_checks={
+                    "quantize_per_tensor.tensor": 1,  # 1: act
+                    "dequantize_per_tensor.tensor": 1,  # 1: act
+                    "dequantize_per_channel.default": 1,  # 1: weight
+                },
+            )
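
For reference, the get_qnode_checks helper above only rewrites operator names into the strings the Tester greps for. A worked example of that mapping, derived directly from the helper (the sample op name and count are illustrative, not additional test output):

# Worked example of the name mapping performed by get_qnode_checks above.
checks = {"dequantize_per_tensor.default": 3}

aten = {f"torch.ops.quantized_decomposed.{op}": n for op, n in checks.items()}
# -> {"torch.ops.quantized_decomposed.dequantize_per_tensor.default": 3}

edge = {
    f"executorch.exir.dialects.edge._ops.quantized_decomposed.{op}".replace(".", "_"): n
    for op, n in checks.items()
}
# -> {"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3}

print(aten)
print(edge)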
