Commit a29b208
Arm backend: qdq folding support for remaining operators (#7340)
* Add TOSA table as custom edge op

  Edge operators that are lowered to TOSA TABLEs are converted to a custom edge IR table-op.

  Signed-off-by: Oscar Andersson <[email protected]>
  Change-Id: I147008c30b9b46c7b8ae1a1c15bc540fea614a69

* Add support for concat q/dq folding

  This is a special case where node.args can be lists with many incoming dq-nodes.

  Signed-off-by: Oscar Andersson <[email protected]>
  Change-Id: Icf511a8bdeaaffb597b18455ab7f1fbd947ce3ca

* Increase q/dq folding coverage

  Add support for q/dq folding of more operators such as hardtanh, maxpool2d, mul, relu, select, sub, to_copy.

  Signed-off-by: Oscar Andersson <[email protected]>
  Change-Id: Ifdabda4c927dade41c000859054696844c546f7b

* Add support for sum q/dq folding

  sum is retraced to an int64 dtype operator after q/dq folding. This patch adds a pass to manually force the dtype to be int8.

  Signed-off-by: Oscar Andersson <[email protected]>
  Change-Id: Ifa737a398c5a878d52cd76a2392499905da085ce

* Complete q/dq folding coverage

  Add support for q/dq folding for the remaining supported ops in the Arm backend.

  Signed-off-by: Oscar Andersson <[email protected]>
  Change-Id: I9012b4a501ce018c9771c729706be3b031a5c7ae

* Remove is_quant_node from NodeVisitor.define_node

  Signed-off-by: Oscar Andersson <[email protected]>
  Change-Id: Ibb17add461dc79e022a7f4accde29f9f9d61b16d

* Fix pyre issues

  Address issues from pyre and add similar # pyre-ignores as in #7362.

  Signed-off-by: Oscar Andersson <[email protected]>
  Change-Id: I6feaa611dcd539b3b0d21a6a7dd696ef7db691ef

---------

Signed-off-by: Oscar Andersson <[email protected]>
1 parent 62c6346 commit a29b208
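For context, q/dq folding targets the dq -> op -> q chains that quantization leaves around each edge operator. A minimal sketch of that pattern using PyTorch's decomposed quantization ops (assuming the quantized_decomposed ops are registered by importing torch.ao.quantization.fx._decomposed; the values are illustrative):

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401 -- assumed to register the quantized_decomposed ops

# In an unfolded quantized graph every op sits between a dequantize and a
# quantize node; folding removes the pair and records the quantization
# parameters (scale, zero_point, qmin, qmax, dtype) in the op node's meta.
x_q = torch.tensor([10, 20], dtype=torch.int8)
scale, zero_point = 0.1, 0

x = torch.ops.quantized_decomposed.dequantize_per_tensor(
    x_q, scale, zero_point, -128, 127, torch.int8
)
y = x + x  # the actual edge op, computed in fp32 before folding
y_q = torch.ops.quantized_decomposed.quantize_per_tensor(
    y, scale, zero_point, -128, 127, torch.int8
)
print(y_q)  # tensor([20, 40], dtype=torch.int8)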

53 files changed: +822 -749 lines changed
backends/arm/_passes/annotate_decomposed_matmul.py (new file)

Lines changed: 89 additions & 0 deletions

@@ -0,0 +1,89 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import itertools
+
+import torch
+from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.fx import GraphModule
+from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
+
+
+class AnnotateDecomposedMatmulPass(ExportPass):
+    """
+    torch.matmul can be decomposed in many ways, for instance:
+    dq -> matmul -> q can become
+    dq -> repeat -> view -> bmm -> view -> dq which makes quantization folding
+    difficult. This pass finds all matmul partitions and annotates their
+    matmul op (can be mm or bmm).
+    """
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        matmul_partitions = get_source_partitions(
+            graph_module.graph,
+            [
+                torch.matmul,
+            ],
+            None,
+        )
+        matmul_partitions = list(
+            itertools.chain.from_iterable(matmul_partitions.values())
+        )
+        matmul_targets = {
+            exir_ops.edge.aten.mm.default,
+            exir_ops.edge.aten.bmm.default,
+        }
+        for partition in matmul_partitions:
+            quantized_input = all(
+                input_node.target == dq_op for input_node in partition.input_nodes
+            )
+            matmul_node = [
+                node for node in partition.nodes if node.target in matmul_targets
+            ][0]
+            if quantized_input:
+                matmul_args = matmul_node.all_input_nodes
+                for i in range(len(matmul_args)):
+                    input_node = partition.input_nodes[i]
+                    matmul_input_node = matmul_args[i]
+                    # Remove partition input dq-node
+                    input_node.replace_all_uses_with(input_node.all_input_nodes[0])
+                    graph_module.graph.erase_node(input_node)
+                    input_node_qargs = input_node.args[1:]
+                    with graph_module.graph.inserting_before(matmul_node):
+                        # Create new dq-node before matmul
+                        dq_node = create_node(
+                            graph=graph_module.graph,
+                            op_target=dq_op,
+                        )
+                        dq_node.args = (matmul_input_node, *input_node_qargs)
+                        matmul_node.replace_input_with(matmul_input_node, dq_node)
+
+            partition_output = list(partition.output_nodes[0].users)[0]
+            quantized_output = partition_output.target == q_op
+            if quantized_output:
+                output_node_qargs = partition_output.args[1:]
+                with graph_module.graph.inserting_after(matmul_node):
+                    # Create q-node after matmul
+                    q_node = create_node(
+                        graph=graph_module.graph,
+                        op_target=q_op,
+                    )
+                    matmul_node.replace_all_uses_with(q_node)
+                    q_node.args = (matmul_node, *output_node_qargs)
+                # Remove partition output q-node
+                partition_output.replace_all_uses_with(
+                    partition_output.all_input_nodes[0]
+                )
+                graph_module.graph.erase_node(partition_output)
+
+        # retrace the graph to update the fake tensor types
+        graph_module = super().call(graph_module).graph_module
+
+        graph_module.recompile()
+        return PassResult(graph_module, True)
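To see why this pass is needed, one can export a module containing torch.matmul and run the core decompositions; a hedged sketch (the exact op sequence depends on input ranks and the PyTorch version):

import torch

class MatMul(torch.nn.Module):
    def forward(self, a, b):
        return torch.matmul(a, b)

ep = torch.export.export(
    MatMul(), (torch.randn(2, 3, 4), torch.randn(2, 4, 5))
)
# Running the decompositions replaces matmul with expand/view/bmm (or mm),
# which is the partition this pass recovers via get_source_partitions.
ep = ep.run_decompositions()
print(ep.graph_module.graph)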

backends/arm/_passes/arm_pass_manager.py

Lines changed: 46 additions & 12 deletions

@@ -11,6 +11,9 @@
 from executorch.backends.arm._passes.annotate_channels_last_dim_order_pass import (
     AnnotateChannelsLastDimOrder,
 )
+from executorch.backends.arm._passes.annotate_decomposed_matmul import (
+    AnnotateDecomposedMatmulPass,
+)
 from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass
 from executorch.backends.arm._passes.conv1d_unsqueeze_pass import Conv1dUnsqueezePass
 from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
@@ -32,7 +35,9 @@
 from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
     FoldAndAnnotateQParamsPass,
     QuantizeFullArgument,
+    RetraceFoldedDtypesPass,
 )
+from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
 from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
     KeepDimsFalseToSqueezePass,
 )
@@ -67,24 +72,15 @@ def transform_to_backend_pipeline(
         self, exported_program: ExportedProgram, compile_spec: list[CompileSpec]
     ):
         """Apply passes before transforming program to backend"""
-        self.add_pass(CastInt64ToInt32Pass(exported_program))
+        self.add_pass(DecomposeLinearPass())
         self.add_pass(RemoveGetItemPass())
-        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
-        self.add_pass(SizeAdjustConv2DPass())
-        self.add_pass(RemoveClonePass())
-        self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(DecomposeLayerNormPass())
-        self.add_pass(UnsqueezeBeforeRepeatPass())
         self.add_pass(DecomposeVarPass())
         self.add_pass(ConvertMeanDimToAveragePool())
         self.add_pass(DecomposeMeanDimPass())
-        self.add_pass(MatchArgRanksPass(exported_program))
-        self.add_pass(DecomposeDivPass())
-        self.add_pass(KeepDimsFalseToSqueezePass())
         self.add_pass(ConvertSplitToSlicePass())
-        self.add_pass(Conv1dUnsqueezePass(exported_program))
-        self.add_pass(DecomposeSoftmaxesPass())
-        self.add_pass(DecomposeLinearPass())
+        # TODO MLETORCH-558
+        self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeFullArgument())
         self.add_pass(
             FoldAndAnnotateQParamsPass(
@@ -93,11 +89,49 @@ def transform_to_backend_pipeline(
                     exir_ops.edge.aten.maximum.default,
                     exir_ops.edge.aten.add.Tensor,
                     exir_ops.edge.aten.avg_pool2d.default,
+                    exir_ops.edge.aten.bmm.default,
+                    exir_ops.edge.aten.cat.default,
                     exir_ops.edge.aten.convolution.default,
+                    exir_ops.edge.aten.clone.default,
+                    exir_ops.edge.aten.exp.default,
+                    exir_ops.edge.aten.expand_copy.default,
                     exir_ops.edge.aten.full.default,
+                    exir_ops.edge.aten.hardtanh.default,
+                    exir_ops.edge.aten.log.default,
+                    exir_ops.edge.aten.max_pool2d.default,
+                    exir_ops.edge.aten.mm.default,
+                    exir_ops.edge.aten.mul.Tensor,
+                    exir_ops.edge.aten.permute_copy.default,
+                    exir_ops.edge.aten.reciprocal.default,
+                    exir_ops.edge.aten.relu.default,
+                    exir_ops.edge.aten.repeat.default,
+                    exir_ops.edge.aten.rsqrt.default,
+                    exir_ops.edge.aten.select_copy.int,
+                    exir_ops.edge.aten.sigmoid.default,
+                    exir_ops.edge.aten.slice_copy.Tensor,
+                    exir_ops.edge.aten.squeeze_copy.dims,
+                    exir_ops.edge.aten.sub.Tensor,
+                    exir_ops.edge.aten.sum.dim_IntList,
+                    exir_ops.edge.aten.tanh.default,
+                    exir_ops.edge.aten.unsqueeze_copy.default,
+                    exir_ops.edge.aten.upsample_nearest2d.vec,
+                    exir_ops.edge.aten.view_copy.default,
                 ]
             )
         )
+        self.add_pass(RetraceFoldedDtypesPass())
+        self.add_pass(InsertTableOpsPass(exported_program))
+        self.add_pass(ConvertExpandCopyToRepeatPass())
+        self.add_pass(UnsqueezeBeforeRepeatPass())
+        self.add_pass(CastInt64ToInt32Pass(exported_program))
+        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
+        self.add_pass(SizeAdjustConv2DPass())
+        self.add_pass(RemoveClonePass())
+        self.add_pass(MatchArgRanksPass(exported_program))
+        self.add_pass(DecomposeDivPass())
+        self.add_pass(KeepDimsFalseToSqueezePass())
+        self.add_pass(Conv1dUnsqueezePass(exported_program))
+        self.add_pass(DecomposeSoftmaxesPass())
         for spec in compile_spec:
             if spec.key == "permute_memory_format":
                 memory_format = spec.value.decode()
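After this pipeline runs, quantization parameters live in node metadata rather than in explicit dq/q nodes, which is what lets is_quant_node be dropped from NodeVisitor.define_node. A hedged sketch of how a consumer might read them (the helper is illustrative; the scale/zp field names are assumed from tosa_quant_utils.QuantArgs as used elsewhere in this commit):

from torch.fx import Node

def input_scale_and_zp(node: Node, idx: int):
    # FoldAndAnnotateQParamsPass stores one QuantArgs entry per folded dq
    # input under meta["input_qparams"], keyed by argument index.
    qargs = node.meta["input_qparams"][idx]
    return qargs.scale, qargs.zp  # field names assumed, not confirmed here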

backends/arm/_passes/conv1d_unsqueeze_pass.py

Lines changed: 2 additions & 20 deletions

@@ -12,10 +12,8 @@
 from executorch.backends.arm._passes.arm_pass_utils import (
     create_node,
     get_param_tensor,
-    insert_q_dq_pair,
     is_param_node,
 )
-from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
@@ -27,10 +25,8 @@ class Conv1dUnsqueezePass(ExportPass):
     supports 2d and 3d convolution. This is done by modifying the graph to do the
     following:
     1) unsqueeze the convolution's input from 3d to 4d
-    2) if the input to unsqueeze is quantized, insert q/dq-pair after unsqueeze
-    3) perform a conv2d (with a modified version of the original conv1d args)
-    4) squeeze the output back down to 3d.
-    5) if all users of squeeze are quantized, insert q/dq-pair before squeeze
+    2) perform a conv2d (with a modified version of the original conv1d args)
+    3) squeeze the output back down to 3d.
     """

     def __init__(self, exported_program: ExportedProgram) -> None:
@@ -94,8 +90,6 @@ def call(self, graph_module: torch.fx.GraphModule):
                 continue

             kernel_node = node.args[1]
-            if kernel_node.target == dq_op:
-                kernel_node = kernel_node.args[0]

             if not is_param_node(self.exported_program, kernel_node):
                 raise AssertionError(
@@ -131,11 +125,6 @@ def call(self, graph_module: torch.fx.GraphModule):
                 )
                 node.replace_input_with(input_node, unsqueeze_before)

-                # If Quantized we must insert unsqueeze --> q --> dq --> node
-                if input_node.target == dq_op:
-                    q_params = input_node.args[1:]
-                    insert_q_dq_pair(graph, unsqueeze_before, q_params)
-
                 with graph.inserting_after(node):
                     squeeze_after = create_node(
                         graph,
@@ -151,13 +140,6 @@ def call(self, graph_module: torch.fx.GraphModule):
                 for user in original_users:
                     user.replace_input_with(node, squeeze_after)

-                # If quantized, insert conv2d --> q --> dq --> squeeze
-                if all(
-                    original_user.target == q_op for original_user in original_users
-                ):
-                    q_params = original_users[0].args[1:]
-                    insert_q_dq_pair(graph, node, q_params)
-
         graph_module.recompile()
         # Since we are overriding "call", we need to call the parent's "call"
         # to retrace the graph and regenerate metadata
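The simplified flow (unsqueeze -> conv2d -> squeeze, with the q/dq handling now left to the folding pass) is numerically equivalent to conv1d. A quick standalone check; the axis placement here is illustrative, the pass may unsqueeze a different dimension:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 16)  # (N, C_in, L)
w = torch.randn(8, 3, 5)   # (C_out, C_in, K)

ref = F.conv1d(x, w)

# 1) unsqueeze the input (and kernel) from 3d to 4d
# 2) run conv2d with the adapted args
# 3) squeeze the result back down to 3d
out = F.conv2d(x.unsqueeze(-1), w.unsqueeze(-1)).squeeze(-1)

assert torch.allclose(ref, out, atol=1e-6)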

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 88 additions & 31 deletions

@@ -6,14 +6,20 @@

 import copy

-from typing import cast, Iterable
+from typing import cast, Dict, Iterable, Set, Tuple

 from executorch.backends.arm.tosa_quant_utils import QuantArgs

 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload

-from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.pass_base import (
+    Argument,
+    ExportPass,
+    NodeMetadata,
+    PassResult,
+    ProxyValue,
+)
 from torch.fx import GraphModule, Node

 q_op: EdgeOpOverload = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
@@ -80,6 +86,46 @@ def __init__(self, targeted_ops: Iterable[EdgeOpOverload]) -> None:
         super().__init__()
         self.targeted_ops = targeted_ops

+    def fold_and_annotate_arg(
+        self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int
+    ) -> None:
+        input_qparams = None
+        nodes_to_remove = set()
+        for arg in arg_list:
+            if not isinstance(arg, Node):
+                return
+            """
+            Make sure arg has requires_grad set to False
+            For parameters that are not quantized, sometimes (i.e. convolution)
+            the Parameter(FakeTensor(...)) has requires_grad set to True, which
+            causes the retracing of the graph to fail with:
+
+            E       RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch.
+            E
+            E       While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
+            E       Original traceback:
+            E         File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward
+            E             x = conv(x)
+            """
+            if arg.op == "placeholder":
+                arg.meta["val"].requires_grad = False
+
+            arg_quant_params = None
+            if arg.target == dq_op:
+                arg_quant_params = QuantArgs.from_operator(arg.target, arg.args)
+                # add arg to nodes_to_remove to fold the dq-node
+                nodes_to_remove.add(arg)
+            if input_qparams is not None and input_qparams != arg_quant_params:
+                # Two args are quantized differently
+                raise RuntimeError("Input qparams does not match!")
+            input_qparams = arg_quant_params
+        if input_qparams is not None:
+            node.meta["input_qparams"][i] = input_qparams
+        for n in nodes_to_remove:
+            assert n.target == dq_op
+            n.replace_all_uses_with(n.args[0])
+            graph_module.graph.erase_node(n)
+
     def call(self, graph_module: GraphModule) -> PassResult:

         # Loop over the graph nodes and find any node in the 'targeted_ops' list.
@@ -98,36 +144,11 @@ def call(self, graph_module: GraphModule) -> PassResult:
                 n.meta["input_qparams"] = {}
                 n.meta["output_qparams"] = {}
                 for i, arg in enumerate(n.args):
-                    if not isinstance(arg, Node):
-                        continue
-
-                    # Make sure arg has requires_grad set to False
-                    # For parameters that are not quantized, sometimes (i.e. convolution)
-                    # the Parameter(FakeTensor(...)) has requires_grad set to True, which
-                    # causes the retracing of the graph to fail with:
-                    #
-                    # E       RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch.
-                    # E
-                    # E       While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
-                    # E       Original traceback:
-                    # E         File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward
-                    # E             x = conv(x)
-                    #
-                    if arg.op == "placeholder":
-                        arg.meta["val"].requires_grad = False
-
-                    if arg.target != dq_op:
-                        continue
-
-                    # arg.target for argument i is a dequant node, extract the information
-                    n.meta["input_qparams"][i] = QuantArgs.from_operator(
-                        arg.target, arg.args
-                    )
+                    if isinstance(arg, list):
+                        self.fold_and_annotate_arg(graph_module, n, arg, i)

-                    # arg.args[0] is the tensor input, replace the input usage
-                    tensor_input = cast(Node, arg.args[0])
-                    n.replace_input_with(arg, tensor_input)
-                    graph_module.graph.erase_node(arg)
+                    elif isinstance(arg, Node):
+                        self.fold_and_annotate_arg(graph_module, n, [arg], i)

                 # Copy the users, since we are modifying it.
                 users_copy = copy.copy(n.users)
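The list branch above covers ops like cat, whose first argument is a list of tensors each potentially fed by its own dq node (the "concat q/dq folding" case from the commit message). A small illustration of that argument shape, assuming a current torch.export API:

import torch

class Cat(torch.nn.Module):
    def forward(self, a, b):
        return torch.cat([a, b], dim=0)

ep = torch.export.export(Cat(), (torch.randn(2), torch.randn(2)))
for node in ep.graph_module.graph.nodes:
    if node.op == "call_function":
        # The first arg of cat is a list of Nodes, which is why
        # fold_and_annotate_arg accepts an arg_list rather than a single arg.
        print(node.target, node.args)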
@@ -181,3 +202,39 @@ def call(self, graph_module: GraphModule) -> PassResult:
             modified = True

         return PassResult(graph_module, modified)
+
+
+class RetraceFoldedDtypesPass(ExportPass):
+    """
+    FoldAndAnnotateQParamsPass folds dq and q nodes. When the graph is retraced
+    some operators are retraced to types that cannot be handled by TOSA. One
+    such example is sum.dim_IntList:
+        q (int8) -> dq (fp32) -> sum (fp32) -> q (int8) ...
+    After folding it becomes:
+        q (int8) -> sum (int64) -> ...
+    This pass changes the types of ops in self.targeted_ops, such as sum, so
+    that the output type of the op matches the type of the output_qparams.
+    """
+
+    targeted_ops: Set[EdgeOpOverload] = {
+        exir_ops.edge.aten.sum.dim_IntList,
+    }
+
+    def call_operator(
+        self,
+        op,  # pyre-ignore
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
+        if op not in self.targeted_ops:
+            return super().call_operator(op, args, kwargs, meta)
+
+        node_kwargs = kwargs.copy()
+        output_qparams = meta["output_qparams"]
+        if len(output_qparams) == 0:
+            return super().call_operator(op, args, kwargs, meta)
+
+        output_dtype = output_qparams[0].dtype
+        node_kwargs["dtype"] = output_dtype
+        return super().call_operator(op, args, node_kwargs, meta)
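The int64 retrace issue this pass works around is easy to reproduce with plain torch: integer reductions promote to int64 by default, and forcing the dtype kwarg, as the pass does, keeps the type at int8:

import torch

x = torch.ones(2, 2, dtype=torch.int8)
print(x.sum(dim=0).dtype)                    # torch.int64 -- default integer promotion
print(x.sum(dim=0, dtype=torch.int8).dtype)  # torch.int8 -- dtype forced, as in the pass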
