Skip to content

Commit 75b0d8a

Browse files
authored
Revert "Revert "Add full operator to fold dq/q handling" (#7351)" (#7362)
* Revert "[Arm] Fix merge conflicts after previous reverts (#7353)" This reverts commit 44e31fb. Change-Id: I12d88419e45a800e43e1e31d21280bc3b63710c6 * Revert "Revert "Add full operator to fold dq/q handling" (#7351)" This reverts commit 11beed1. Change-Id: I6ba9b37f069c3ad819114fe1384659bc0f68135f * Fix type checker issues Signed-off-by: Per Åstrand <[email protected]> Change-Id: Iee35eeb4af28a037848570b1d5143380222f549d * Address pyre checks from Meta internal tools Signed-off-by: Per Åstrand <[email protected]> Change-Id: I60fb32ad55a3f6d3617993481ab0c1ed46cf778c * Fix more typecheckings Change-Id: I37fb7468b5cb040b916fabb2784ec475318bddba * Reapply pyre-ignores after lintrunner moves lintrunner moved the pyre-ignores to the wrong line, re-apply to the right line again. Signed-off-by: Per Åstrand <[email protected]> Change-Id: Icbb6cf6f2d65d6fe1c00e7c4a738691ffeb7acb2 --------- Signed-off-by: Per Åstrand <[email protected]>
1 parent f341da8 commit 75b0d8a

File tree

83 files changed

+1127
-287
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

83 files changed

+1127
-287
lines changed

backends/arm/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,8 @@ The Arm Backend should be considered a prototype quality at this point, likely s
104104
## Current flows
105105

106106
The ArmBackend has a two stage process,
107-
- Compile to TOSA to rationalise the graph into known hardware support profiles. Currently this is to v0.80.0 TOSA BI with specific concern to a subset which gives support on Ethos-U55, the target of the initial prototype efforts.
108-
- Lower via the ethos-u-vela compilation flow which takes TOSA v0.80.0 as an input and produces a low level commandstream for the hardware which is then passed via the delegate to the ethos-u-core-driver for direct execution.
107+
- Compile to TOSA to rationalise the graph into known hardware support profiles. Currently this is to v0.80 TOSA BI with specific concern to a subset which gives support on Ethos-U55, the target of the initial prototype efforts.
108+
- Lower via the ethos-u-vela compilation flow which takes TOSA v0.80 as an input and produces a low level commandstream for the hardware which is then passed via the delegate to the ethos-u-core-driver for direct execution.
109109

110110
The ArmPartitioner is currenly used to ensure the operations converted are Ethos-U compatible, but will be extended to offer spec-correct TOSA Base inference and TOSA Main Inference generation in future.
111111

backends/arm/_passes/arm_pass_manager.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929
DecomposeSoftmaxesPass,
3030
)
3131
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
32+
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
33+
FoldAndAnnotateQParamsPass,
34+
QuantizeFullArgument,
35+
)
3236
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
3337
KeepDimsFalseToSqueezePass,
3438
)
@@ -50,6 +54,7 @@
5054
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
5155
from executorch.exir import ExportedProgram
5256
from executorch.exir.backend.compile_spec_schema import CompileSpec
57+
from executorch.exir.dialects._ops import ops as exir_ops
5358
from executorch.exir.pass_manager import PassManager
5459

5560

@@ -80,6 +85,19 @@ def transform_to_backend_pipeline(
8085
self.add_pass(Conv1dUnsqueezePass(exported_program))
8186
self.add_pass(DecomposeSoftmaxesPass())
8287
self.add_pass(DecomposeLinearPass())
88+
self.add_pass(QuantizeFullArgument())
89+
self.add_pass(
90+
FoldAndAnnotateQParamsPass(
91+
[
92+
exir_ops.edge.aten.minimum.default,
93+
exir_ops.edge.aten.maximum.default,
94+
exir_ops.edge.aten.add.Tensor,
95+
exir_ops.edge.aten.avg_pool2d.default,
96+
exir_ops.edge.aten.convolution.default,
97+
exir_ops.edge.aten.full.default,
98+
]
99+
)
100+
)
83101
for spec in compile_spec:
84102
if spec.key == "permute_memory_format":
85103
memory_format = spec.value.decode()
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import copy
8+
9+
from typing import cast, Iterable
10+
11+
from executorch.backends.arm.tosa_quant_utils import QuantArgs
12+
13+
from executorch.exir.dialects._ops import ops as exir_ops
14+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
15+
16+
from executorch.exir.pass_base import ExportPass, PassResult
17+
from torch.fx import GraphModule, Node
18+
19+
q_op: EdgeOpOverload = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
20+
dq_op: EdgeOpOverload = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
21+
22+
23+
def get_input_qparams(node: Node) -> dict[int, QuantArgs]:
24+
"""
25+
Get the input quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'.
26+
Raises a ValueError if the node doesn't have any parameters set.
27+
"""
28+
if "input_qparams" not in node.meta.keys():
29+
raise ValueError(f"No input quantization parameter found in node {node}")
30+
input_qparams = cast(dict[int, QuantArgs], node.meta["input_qparams"])
31+
if len(input_qparams) == 0:
32+
raise ValueError(f"No input quantization parameter found in node {node}")
33+
return input_qparams
34+
35+
36+
def get_output_qparams(node: Node) -> dict[int, QuantArgs]:
37+
"""
38+
Get the output quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'.
39+
Raises a ValueError if the node doesn't have any parameters set.
40+
"""
41+
if "output_qparams" not in node.meta.keys():
42+
raise ValueError(f"No output quantization parameter found in node {node}")
43+
input_qparams = cast(dict[int, QuantArgs], node.meta["output_qparams"])
44+
if len(input_qparams) == 0:
45+
raise ValueError(f"No output quantization parameter found in node {node}")
46+
return input_qparams
47+
48+
49+
class FoldAndAnnotateQParamsPass(ExportPass):
50+
"""
51+
A pass that walks the graph and removes any DQ and Q nodes before and after the target
52+
node in the supplied list of operators.
53+
The quantization parameters from the DQ/Q nodes are stored as meta values to be
54+
accessible for later lowering and serialization passes.
55+
The assumption is that the quantization annotatation adds DQ nodes for all tensor
56+
inputs to the target one Q node to the output.
57+
58+
Example ('executorch_exir_dialects_edge__ops_' prefix removed from operators for readability):
59+
60+
x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8)
61+
62+
x_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(x_q, 0.05487706884741783, -128, -128, 127, torch.int8)
63+
aten_add_tensor: "f32[5]" = ops_aten_add_Tensor(x_dq, x_dq)
64+
aten_add_tensor_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(aten_add_tensor, 0.05487706884741783, -128, -128, 127, torch.int8)
65+
66+
output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor_q, 0.05487706884741783, -128, -128, 127, torch.int8)
67+
68+
Becomes:
69+
x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8)
70+
71+
aten_add_tensor: "i8[5]" = aten_add_Tensor(x_q, x_q)
72+
73+
output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor_q, 0.05487706884741783, -128, -128, 127, torch.int8)
74+
75+
The quantization parameters for x_dq and aten_add_tensor_q are store in meta for the aten_add_tensor node.
76+
77+
"""
78+
79+
def __init__(self, targeted_ops: Iterable[EdgeOpOverload]) -> None:
80+
super().__init__()
81+
self.targeted_ops = targeted_ops
82+
83+
def call(self, graph_module: GraphModule) -> PassResult:
84+
85+
# Loop over the graph nodes and find any node in the 'targeted_ops' list.
86+
for n in graph_module.graph.nodes:
87+
n = cast(Node, n)
88+
if n.op != "call_function" or n.target not in self.targeted_ops:
89+
continue
90+
91+
# Make sure we haven't already set qparams meta information on the node
92+
assert "input_qparams" not in n.meta.keys()
93+
assert "output_qparams" not in n.meta.keys()
94+
95+
# for the inputs and outputs search the graph for quantization info and
96+
# store the information in a dict with order of the _tensor_ inputs as key,
97+
# ignoring any other arguments to the target node.
98+
n.meta["input_qparams"] = {}
99+
n.meta["output_qparams"] = {}
100+
for i, arg in enumerate(n.args):
101+
if not isinstance(arg, Node):
102+
continue
103+
104+
# Make sure arg has requires_grad set to False
105+
# For parameters that are not quantized, sometimes (i.e. convolution)
106+
# the Parameter(FakeTensor(...)) has requires_grad set to True, which
107+
# causes the retracing of the graph to fail with:
108+
#
109+
# E RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch.
110+
# E
111+
# E While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
112+
# E Original traceback:
113+
# E File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward
114+
# E x = conv(x)
115+
#
116+
if arg.op == "placeholder":
117+
arg.meta["val"].requires_grad = False
118+
119+
if arg.target != dq_op:
120+
continue
121+
122+
# arg.target for argument i is a dequant node, extract the information
123+
n.meta["input_qparams"][i] = QuantArgs.from_operator(
124+
arg.target, arg.args
125+
)
126+
127+
# arg.args[0] is the tensor input, replace the input usage
128+
tensor_input = cast(Node, arg.args[0])
129+
n.replace_input_with(arg, tensor_input)
130+
graph_module.graph.erase_node(arg)
131+
132+
# Copy the users, since we are modifying it.
133+
users_copy = copy.copy(n.users)
134+
for i, user in enumerate(users_copy):
135+
if user.target != q_op:
136+
continue
137+
138+
# quantization node found here, store the quantization parameters in meta value
139+
n.meta["output_qparams"][i] = QuantArgs.from_operator(
140+
user.target, user.args
141+
)
142+
143+
user.replace_all_uses_with(n)
144+
graph_module.graph.erase_node(user)
145+
146+
# retrace the graph to update the fake tensor types
147+
graph_module = super().call(graph_module).graph_module
148+
149+
graph_module.recompile()
150+
return PassResult(graph_module, True)
151+
152+
153+
class QuantizeFullArgument(ExportPass):
154+
"""
155+
Make sure the fill_value for full.default is quantized. This pass needs to be run before
156+
the folding pass above to make sure that the retraced output of the full.default op is
157+
the right dtype.
158+
"""
159+
160+
def call(self, graph_module: GraphModule) -> PassResult:
161+
modified = False
162+
# Loop over the graph nodes and find any node in the 'targeted_ops' list.
163+
for n in graph_module.graph.nodes:
164+
n = cast(Node, n)
165+
if n.target != exir_ops.edge.aten.full.default:
166+
continue
167+
168+
# Make sure we have a quantized operator
169+
user = list(n.users)[0]
170+
if user.target != q_op:
171+
continue
172+
173+
qargs = QuantArgs.from_operator(user.target, user.args)
174+
if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype:
175+
# replace the node arg with a quantized dito and also set dtype
176+
# to get the right output according to the Edge IR specification:
177+
# exir/dialects/edge/edge.yaml:3596
178+
quantized_full_value = qargs.quantize_value(n.args[1]).item()
179+
n.update_arg(1, quantized_full_value)
180+
n.update_kwarg("dtype", qargs.dtype)
181+
modified = True
182+
183+
return PassResult(graph_module, modified)

backends/arm/arm_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def ethosu_compile_spec(
9090
if extra_flags is not None:
9191
self.compiler_flags.append(extra_flags)
9292

93-
base_tosa_version = "TOSA-0.80.0+BI"
93+
base_tosa_version = "TOSA-0.80+BI"
9494
if "u55" in config:
9595
# Add the Ethos-U55 extension marker
9696
base_tosa_version += "+u55"

backends/arm/operator_support/right_shift_support.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ class RightShiftSupported(SupportedTOSAOperatorCheck):
2323
targets = [exir_ops.edge.aten.__rshift__.Scalar]
2424

2525
tosa_specs = [
26-
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
27-
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
26+
TosaSpecification.create_from_string("TOSA-0.80+BI"),
27+
TosaSpecification.create_from_string("TOSA-0.80+MI"),
2828
]
2929

3030
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

backends/arm/operator_support/to_copy_support.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ class ToCopySupported(SupportedTOSAOperatorCheck):
2525
targets = [exir_ops.edge.aten._to_copy.default]
2626

2727
tosa_specs = [
28-
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
29-
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
28+
TosaSpecification.create_from_string("TOSA-0.80+BI"),
29+
TosaSpecification.create_from_string("TOSA-0.80+MI"),
3030
]
3131

3232
SupportedTypeDict = dict[torch.dtype, list[torch.dtype]]

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool
3535
_tosa_spec_dicts: dict[
3636
TosaSpecification, dict[str, Type[SupportedTOSAOperatorCheck]]
3737
] = {
38-
TosaSpecification.create_from_string("TOSA-0.80.0+BI"): {},
39-
TosaSpecification.create_from_string("TOSA-0.80.0+MI"): {},
38+
TosaSpecification.create_from_string("TOSA-0.80+BI"): {},
39+
TosaSpecification.create_from_string("TOSA-0.80+MI"): {},
4040
}
4141

4242

@@ -94,6 +94,8 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
9494
exir_ops.edge.aten.sigmoid.default,
9595
exir_ops.edge.aten.mean.dim,
9696
exir_ops.edge.aten.mm.default,
97+
exir_ops.edge.aten.minimum.default,
98+
exir_ops.edge.aten.maximum.default,
9799
exir_ops.edge.aten.repeat.default,
98100
exir_ops.edge.aten.reciprocal.default,
99101
exir_ops.edge.aten.relu.default,

backends/arm/operators/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
op_get_item,
2020
op_hardtanh,
2121
op_log,
22+
op_max,
2223
op_max_pool2d,
24+
op_min,
2325
op_mm,
2426
op_mul,
2527
op_permute,

backends/arm/operators/node_visitor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ class NodeVisitor:
2525
# When all node_visitors has been refactored to target a specific
2626
# version, this list should be removed.
2727
tosa_specs = [
28-
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
29-
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
28+
TosaSpecification.create_from_string("TOSA-0.80+BI"),
29+
TosaSpecification.create_from_string("TOSA-0.80+MI"),
3030
]
3131

3232
def __init__(self, exported_program: ExportedProgram, tosa_spec: TosaSpecification):
@@ -46,8 +46,8 @@ def define_node(
4646

4747
# container for all node visitors
4848
_node_visitor_dicts = {
49-
TosaSpecification.create_from_string("TOSA-0.80.0+BI"): {},
50-
TosaSpecification.create_from_string("TOSA-0.80.0+MI"): {},
49+
TosaSpecification.create_from_string("TOSA-0.80+BI"): {},
50+
TosaSpecification.create_from_string("TOSA-0.80+MI"): {},
5151
}
5252

5353

0 commit comments

Comments
 (0)