Skip to content

Commit 412b5b0

Browse files
authored
Merge branch 'main' into matmul_unmark_flaky
2 parents 4f40c44 + b73f9d5 commit 412b5b0

File tree

102 files changed

+428
-349
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

102 files changed

+428
-349
lines changed

.github/workflows/_link_check.yml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ jobs:
1818
timeout: 120
1919
script: |
2020
./scripts/lint_urls.sh $(
21-
{ [ "${{ github.event_name }}" = "pull_request" ] \
22-
&& git diff --name-only "${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }}"; } \
23-
|| \
24-
{ [ "${{ github.event_name }}" = "push" ] \
25-
&& git diff --name-only "${{ github.event.before }}...${{ github.sha }}"; }
21+
if [ "${{ github.event_name }}" = "pull_request" ]; then
22+
echo "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}"
23+
else
24+
echo "${{ github.event.before }}" "${{ github.sha }}"
25+
fi
2626
) || {
2727
echo
2828
echo "URL lint failed."
@@ -43,11 +43,11 @@ jobs:
4343
timeout: 60
4444
script: |
4545
./scripts/lint_xrefs.sh $(
46-
{ [ "${{ github.event_name }}" = "pull_request" ] \
47-
&& git diff --name-only "${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }}"; } \
48-
|| \
49-
{ [ "${{ github.event_name }}" = "push" ] \
50-
&& git diff --name-only "${{ github.event.before }}...${{ github.sha }}"; }
46+
if [ "${{ github.event_name }}" = "pull_request" ]; then
47+
echo "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}"
48+
else
49+
echo "${{ github.event.before }}" "${{ github.sha }}"
50+
fi
5151
) || {
5252
echo
5353
echo "Xref lint failed."

backends/qualcomm/_passes/annotate_quant_attrs.py

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,17 @@
77
from typing import Any, Dict
88

99
import torch
10-
from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
10+
from executorch.backends.qualcomm.builders.utils import get_parameter
1111
from executorch.backends.qualcomm.utils.constants import (
12-
QCOM_AXIS,
13-
QCOM_BLOCK_SIZE,
1412
QCOM_DTYPE,
1513
QCOM_ENCODING,
1614
QCOM_QUANT_ATTRS,
1715
QCOM_QUANT_MAX,
1816
QCOM_QUANT_MIN,
1917
QCOM_REQUANTIZE,
2018
QCOM_SCALE,
21-
QCOM_SCALES,
2219
QCOM_ZERO_POINT,
23-
QCOM_ZERO_POINTS,
2420
)
25-
from executorch.exir.dialects._ops import ops as exir_ops
2621
from executorch.exir.pass_base import ExportPass, PassResult
2722

2823
from .utils import dq_ops, get_quant_attrs, q_ops
@@ -101,43 +96,9 @@ def _annotate_requant(self, n):
10196
n.args[0].meta.setdefault(QCOM_REQUANTIZE, {})
10297
n.args[0].meta[QCOM_REQUANTIZE][user_node.name] = dq_attrs
10398

104-
# Dequant all the fold_quant parameters back to fp32.
105-
# If an operation is not supported by QNN and got fallback, it will expect a fp32 param.
106-
def _dequant_fold_params(self, n, quant_attrs, param):
107-
if quant_attrs[QCOM_ENCODING] in [
108-
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default
109-
]:
110-
dim, axis = param.dim(), quant_attrs[QCOM_AXIS]
111-
scales = self._expand(quant_attrs[QCOM_SCALES], dim, axis)
112-
offsets = self._expand(quant_attrs[QCOM_ZERO_POINTS], dim, axis)
113-
param = param.sub(offsets).mul(scales).to(torch.float32).contiguous()
114-
elif quant_attrs[QCOM_ENCODING] in [
115-
exir_ops.edge.pt2e_quant.dequantize_affine.default
116-
]:
117-
param = torch.ops.pt2e_quant.dequantize_affine(
118-
param,
119-
block_size=quant_attrs[QCOM_BLOCK_SIZE],
120-
scale=quant_attrs[QCOM_SCALE],
121-
zero_point=quant_attrs[QCOM_ZERO_POINT],
122-
input_dtype=quant_attrs[QCOM_DTYPE],
123-
quant_min=quant_attrs[QCOM_QUANT_MIN],
124-
quant_max=quant_attrs[QCOM_QUANT_MAX],
125-
output_dtype=torch.float32,
126-
)
127-
else:
128-
scale = quant_attrs[QCOM_SCALE]
129-
offset = quant_attrs[QCOM_ZERO_POINT]
130-
param = param.sub(offset).mul(scale).to(torch.float32).contiguous()
131-
132-
set_parameter(param, n.args[0], self.edge_program)
133-
n.args[0].meta["val"] = param
134-
13599
def _annotate_quant_attrs(
136100
self, graph_module: torch.fx.GraphModule
137101
) -> torch.fx.GraphModule:
138-
# Keep track of const params that has been dequant, so it does not get
139-
# dequant multiple times if the const param has more than 1 user
140-
visited_const_param = set()
141102
for n in graph_module.graph.nodes:
142103
self._annotate_requant(n)
143104
# With fold_quant enabled, check if the input of dq op is quantized param.
@@ -149,10 +110,6 @@ def _annotate_quant_attrs(
149110
quant_attrs = get_quant_attrs(self.edge_program, n)
150111
self._annotate_source_nodes(n, quant_attrs)
151112

152-
if param is not None and n.args[0] not in visited_const_param:
153-
visited_const_param.add(n.args[0])
154-
self._dequant_fold_params(n, quant_attrs, param)
155-
156113
return graph_module
157114

158115
def call(self, graph_module: torch.fx.GraphModule):

backends/qualcomm/_passes/convert_conv1d_to_conv2d.py

Lines changed: 94 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,8 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import torch
8-
import torch.nn as nn
98
from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
109
from executorch.backends.qualcomm.utils.constants import QCOM_REQUANTIZE
11-
from executorch.exir.dialects._ops import ops as exir_ops
1210
from executorch.exir.pass_base import ExportPass, PassResult
1311

1412
from .utils import copy_meta
@@ -23,16 +21,43 @@ class ConvertConv1dToConv2d(ExportPass):
2321
def __init__(self, edge_program: torch.export.ExportedProgram):
2422
super(ConvertConv1dToConv2d, self).__init__()
2523
self.edge_program = edge_program
24+
self.conv_op_map = {
25+
torch.ops.aten.conv1d.default: torch.ops.aten.conv2d.default,
26+
torch.ops.aten.conv_transpose1d.default: torch.ops.aten.conv_transpose2d.input,
27+
}
28+
29+
def append_qdq(
30+
self,
31+
graph_module: torch.fx.GraphModule,
32+
node: torch.fx.Node,
33+
qdq_node: torch.fx.Node,
34+
):
35+
q_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
36+
dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
37+
if qdq_node.target not in {q_op, dq_op}:
38+
return node
39+
40+
with graph_module.graph.inserting_after(node):
41+
q_args = (node, *qdq_node.args[1:])
42+
q_node = graph_module.graph.create_node("call_function", q_op, q_args)
43+
q_node.meta = copy_meta(node.meta)
44+
q_node.meta["val"] = q_node.meta["val"].to(q_args[-1])
45+
with graph_module.graph.inserting_after(q_node):
46+
dq_args = (q_node, *qdq_node.args[1:])
47+
dq_node = graph_module.graph.create_node(
48+
"call_function", dq_op, dq_args
49+
)
50+
dq_node.meta = copy_meta(node.meta)
51+
52+
return dq_node
2653

2754
def call(self, graph_module: torch.fx.GraphModule):
2855
graph = graph_module.graph
29-
conv_op = exir_ops.edge.aten.convolution.default
3056
for node in graph.nodes:
31-
if node.target == conv_op and node.meta["val"].dim() == 3:
32-
57+
if node.target in self.conv_op_map:
3358
input_node = node.args[0]
3459
with graph_module.graph.inserting_after(input_node):
35-
unsqueeze_op = exir_ops.edge.aten.unsqueeze_copy.default
60+
unsqueeze_op = torch.ops.aten.unsqueeze_copy.default
3661
unsqueeze_node = graph.create_node(
3762
"call_function",
3863
unsqueeze_op,
@@ -44,52 +69,88 @@ def call(self, graph_module: torch.fx.GraphModule):
4469
unsqueeze_node.meta = copy_meta(
4570
input_node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
4671
)
72+
qdq_node_after_unsqueeze = self.append_qdq(
73+
graph_module=graph_module,
74+
node=unsqueeze_node,
75+
qdq_node=input_node,
76+
)
4777

48-
with graph_module.graph.inserting_after(unsqueeze_node):
49-
50-
filter_node = node.args[1]
78+
with graph_module.graph.inserting_after(qdq_node_after_unsqueeze):
79+
filter_arg = node.args[1]
80+
filter_node = (
81+
filter_arg
82+
if filter_arg.op == "placeholder"
83+
else node.args[1].args[0]
84+
)
5185
filter_node.meta["val"] = (
5286
filter_node.meta["val"].unsqueeze(2).contiguous()
5387
)
54-
filter_tensor = get_parameter(filter_node, self.edge_program)
55-
# Ensure tensor is nn.Parameter type, so program does not fail during edge_program._validate()
56-
filter_tensor = nn.Parameter(filter_tensor.unsqueeze(2))
57-
set_parameter(filter_tensor, filter_node, self.edge_program)
88+
filter_tensor = get_parameter(
89+
filter_node, self.edge_program
90+
).unsqueeze(2)
91+
set_parameter(
92+
(
93+
torch.nn.Parameter(filter_tensor)
94+
if filter_tensor.dtype == torch.float
95+
else filter_tensor
96+
),
97+
filter_node,
98+
self.edge_program,
99+
)
58100

101+
num_args = len(node.args)
59102
bias_node = node.args[2]
60-
stride = [1] + node.args[3]
61-
padding = [0] + node.args[4]
62-
dilation = [1] + node.args[5]
63-
transpose = node.args[6]
64-
output_padding = [0] + node.args[7]
65-
groups = node.args[8]
66-
67-
conv2d_node = graph.create_node(
68-
"call_function",
69-
conv_op,
70-
(
71-
unsqueeze_node,
72-
filter_node,
103+
stride = [1] + node.args[3] if num_args > 3 else [1, 1]
104+
padding = [0] + node.args[4] if num_args > 4 else [0, 0]
105+
if node.target == torch.ops.aten.conv1d.default:
106+
dilation = [1] + node.args[5] if num_args > 5 else [1, 1]
107+
groups = node.args[6] if num_args > 5 else 1
108+
conv_args = (
109+
qdq_node_after_unsqueeze,
110+
node.args[1],
73111
bias_node,
74112
stride,
75113
padding,
76114
dilation,
77-
transpose,
115+
groups,
116+
)
117+
else:
118+
output_padding = (
119+
[0] + node.args[5] if num_args > 5 else [0, 0]
120+
)
121+
groups = node.args[6] if num_args > 6 else 1
122+
dilation = [1] + node.args[7] if num_args > 7 else [1, 1]
123+
conv_args = (
124+
qdq_node_after_unsqueeze,
125+
node.args[1],
126+
bias_node,
127+
stride,
128+
padding,
78129
output_padding,
79130
groups,
80-
),
131+
dilation,
132+
)
133+
conv2d_node = graph.create_node(
134+
"call_function",
135+
self.conv_op_map[node.target],
136+
conv_args,
81137
)
82138
conv2d_node.meta = copy_meta(
83139
node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
84140
)
141+
qdq_node_after_conv2d = self.append_qdq(
142+
graph_module=graph_module,
143+
node=conv2d_node,
144+
qdq_node=list(node.users)[0],
145+
)
85146

86-
with graph_module.graph.inserting_after(conv2d_node):
87-
squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
147+
with graph_module.graph.inserting_after(qdq_node_after_conv2d):
148+
squeeze_op = torch.ops.aten.squeeze_copy.dims
88149
squeeze_node = graph.create_node(
89150
"call_function",
90151
squeeze_op,
91152
(
92-
conv2d_node,
153+
qdq_node_after_conv2d,
93154
[2],
94155
),
95156
)
@@ -102,8 +163,10 @@ def call(self, graph_module: torch.fx.GraphModule):
102163
QCOM_REQUANTIZE
103164
]
104165
conv2d_node.meta.pop(QCOM_REQUANTIZE, None)
166+
105167
for user in node.users.copy():
106168
user.replace_input_with(node, squeeze_node)
169+
107170
graph.eliminate_dead_code()
108171
graph_module.recompile()
109172
return PassResult(graph_module, True)

backends/qualcomm/_passes/expand_broadcast_tensor_shape.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from executorch.exir.pass_base import ExportPass, PassResult
1010
from executorch.exir.passes import dead_code_elimination_pass
1111

12+
from .utils import dq_ops
13+
1214

1315
class ExpandBroadcastTensorShape(ExportPass):
1416
"""
@@ -45,9 +47,13 @@ def traverse_broadcast_node(self, graph_module: torch.fx.GraphModule):
4547
exir_ops.edge.aten.view_copy.default,
4648
(arg, tuple(new_rank)),
4749
)
50+
# try skip dq_ops to get correct param node if applicable
51+
arg_meta = (
52+
arg.args[0].meta if arg.target in dq_ops else arg.meta
53+
)
4854
# meta needs to be copied elementwisely for fake-tensor
4955
# to be updated correctly and not affect meta of arg
50-
for k, v in arg.meta.items():
56+
for k, v in arg_meta.items():
5157
reshape_node.meta[k] = v
5258
reshape_node.meta["val"] = reshape_node.meta["val"].reshape(
5359
new_rank

backends/qualcomm/_passes/fold_qdq.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import torch
7+
from executorch.backends.qualcomm.builders.utils import is_parameter
8+
from executorch.backends.qualcomm.utils.constants import QCOM_BYPASS_NODE
79
from executorch.exir.dialects._ops import ops as exir_ops
810
from executorch.exir.pass_base import ExportPass, PassResult
911
from executorch.exir.passes import dead_code_elimination_pass
@@ -16,23 +18,38 @@ class FoldQDQ(ExportPass):
1618
Erase QDQ pattern.
1719
"""
1820

19-
def __init__(self):
21+
def __init__(self, edge_program: torch.export.ExportedProgram, force_fold=False):
2022
super(FoldQDQ, self).__init__()
23+
self.edge_program = edge_program
24+
self.force_fold = force_fold
2125

22-
def _fold(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
26+
def _annotate_bypass(self, node):
27+
node.meta[QCOM_BYPASS_NODE] = True
28+
for arg in node.args:
29+
if isinstance(arg, torch.fx.Node) and arg.op == "call_function":
30+
self._annotate_bypass(arg)
31+
32+
def _fold_dq(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
2333
# remove dq
2434
for n in graph_module.graph.nodes:
2535
user_list = list(n.users.keys())
2636
if n.target not in dq_ops:
2737
continue
28-
for user_n in user_list:
29-
user_n.replace_input_with(n, n.args[0])
30-
graph_module.graph.erase_node(n)
3138

39+
# skip parameters & buffers
40+
if not self.force_fold and is_parameter(n.args[0], self.edge_program):
41+
self._annotate_bypass(n)
42+
else:
43+
for user_n in user_list:
44+
user_n.replace_input_with(n, n.args[0])
45+
graph_module.graph.erase_node(n)
46+
47+
def _fold_q(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
3248
# remove q
3349
for n in graph_module.graph.nodes:
3450
if n.target not in q_ops:
3551
continue
52+
3653
to_be_removed = [n]
3754
source_n = n.args[0]
3855

@@ -57,7 +74,8 @@ def _fold(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
5774
graph_module.graph.erase_node(n)
5875

5976
def call(self, graph_module: torch.fx.GraphModule):
60-
self._fold(graph_module)
77+
self._fold_dq(graph_module)
78+
self._fold_q(graph_module)
6179
graph_module.recompile()
6280
dead_code_elimination_pass(graph_module)
6381
return PassResult(graph_module, True)

0 commit comments

Comments
 (0)