Skip to content

Commit 0fb5b0d

Browse files
committed
Update on "[ET-VK] Changing all conv 2d pw ints from uint16 to int since it slightly improves perf."
This diff changes all integers in conv 2d pw op shader from uint16 to int in the Vulkan backend of Executorch. The change is made to improve performance since the shader does not appear to be register bound. Differential Revision: [D67906023](https://our.internmc.facebook.com/intern/diff/D67906023/) [ghstack-poisoned]
2 parents 7469edb + 0019161 commit 0fb5b0d

File tree

140 files changed

+2096
-1265
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

140 files changed

+2096
-1265
lines changed

.ci/docker/ci_commit_pins/buck2.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2024-05-15
1+
2024-12-16
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import itertools
8+
9+
import torch
10+
from executorch.backends.arm._passes.arm_pass_utils import create_node
11+
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
from executorch.exir.pass_base import ExportPass, PassResult
14+
from torch.fx import GraphModule
15+
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
16+
17+
18+
class AnnotateDecomposedMatmulPass(ExportPass):
    """
    torch.matmul can be decomposed in many ways, for instance:
    dq -> matmul -> q can become
    dq -> repeat -> view -> bmm -> view -> q which makes quantization folding
    difficult. This pass finds all matmul partitions and annotates the
    matmul-op of each one (can be mm or bmm) by moving the partition's
    dq/q-nodes so that they sit directly around that op.
    """

    def call(self, graph_module: GraphModule) -> PassResult:
        """Move the dq/q-nodes of every torch.matmul partition next to its mm/bmm.

        Args:
            graph_module: The graph module to transform. Its graph is modified
                in place before being retraced.

        Returns:
            A PassResult wrapping the retraced graph module. The modified flag
            is always True, whether or not any partition was rewritten.

        Raises:
            IndexError: If a partition contains no mm/bmm node, if the
                partition has fewer input nodes than the matmul-op has inputs,
                or if the first partition output has no users.
        """
        graph = graph_module.graph
        partitions_by_source = get_source_partitions(
            graph,
            [
                torch.matmul,
            ],
            None,
        )
        matmul_partitions = list(
            itertools.chain.from_iterable(partitions_by_source.values())
        )
        matmul_targets = {
            exir_ops.edge.aten.mm.default,
            exir_ops.edge.aten.bmm.default,
        }
        for partition in matmul_partitions:
            quantized_input = all(
                input_node.target == dq_op for input_node in partition.input_nodes
            )
            matmul_node = [
                node for node in partition.nodes if node.target in matmul_targets
            ][0]
            if quantized_input:
                # all_input_nodes returns a snapshot, so rewiring the matmul
                # inside the loop does not disturb the iteration.
                # NOTE(review): pairing by index assumes partition.input_nodes
                # is ordered like the matmul inputs -- confirm.
                for i, matmul_input_node in enumerate(matmul_node.all_input_nodes):
                    input_node = partition.input_nodes[i]
                    # Read everything needed from the dq-node *before* erasing
                    # it: erase_node nulls out Node-typed args of the erased
                    # node, so reading them afterwards is fragile.
                    dq_source = input_node.all_input_nodes[0]
                    input_node_qargs = input_node.args[1:]
                    # Remove partition input dq-node
                    input_node.replace_all_uses_with(dq_source)
                    graph.erase_node(input_node)
                    with graph.inserting_before(matmul_node):
                        # Create new dq-node before matmul
                        dq_node = create_node(
                            graph=graph,
                            op_target=dq_op,
                        )
                        dq_node.args = (matmul_input_node, *input_node_qargs)
                        matmul_node.replace_input_with(matmul_input_node, dq_node)

            partition_output = list(partition.output_nodes[0].users)[0]
            quantized_output = partition_output.target == q_op
            if quantized_output:
                output_node_qargs = partition_output.args[1:]
                with graph.inserting_after(matmul_node):
                    # Create q-node after matmul
                    q_node = create_node(
                        graph=graph,
                        op_target=q_op,
                    )
                    # Redirect the users first and only then feed the matmul
                    # into the q-node; the other order would make the q-node
                    # consume itself.
                    matmul_node.replace_all_uses_with(q_node)
                    q_node.args = (matmul_node, *output_node_qargs)
                # Remove partition output q-node
                partition_output.replace_all_uses_with(
                    partition_output.all_input_nodes[0]
                )
                graph.erase_node(partition_output)

        # retrace the graph to update the fake tensor types
        graph_module = super().call(graph_module).graph_module

        graph_module.recompile()
        return PassResult(graph_module, True)

backends/arm/_passes/arm_pass_manager.py

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
from executorch.backends.arm._passes.annotate_channels_last_dim_order_pass import (
1212
AnnotateChannelsLastDimOrder,
1313
)
14+
from executorch.backends.arm._passes.annotate_decomposed_matmul import (
15+
AnnotateDecomposedMatmulPass,
16+
)
1417
from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass
1518
from executorch.backends.arm._passes.conv1d_unsqueeze_pass import Conv1dUnsqueezePass
1619
from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
@@ -32,7 +35,9 @@
3235
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
3336
FoldAndAnnotateQParamsPass,
3437
QuantizeFullArgument,
38+
RetraceFoldedDtypesPass,
3539
)
40+
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
3641
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
3742
KeepDimsFalseToSqueezePass,
3843
)
@@ -67,24 +72,15 @@ def transform_to_backend_pipeline(
6772
self, exported_program: ExportedProgram, compile_spec: list[CompileSpec]
6873
):
6974
"""Apply passes before transforming program to backend"""
70-
self.add_pass(CastInt64ToInt32Pass(exported_program))
75+
self.add_pass(DecomposeLinearPass())
7176
self.add_pass(RemoveGetItemPass())
72-
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
73-
self.add_pass(SizeAdjustConv2DPass())
74-
self.add_pass(RemoveClonePass())
75-
self.add_pass(ConvertExpandCopyToRepeatPass())
7677
self.add_pass(DecomposeLayerNormPass())
77-
self.add_pass(UnsqueezeBeforeRepeatPass())
7878
self.add_pass(DecomposeVarPass())
7979
self.add_pass(ConvertMeanDimToAveragePool())
8080
self.add_pass(DecomposeMeanDimPass())
81-
self.add_pass(MatchArgRanksPass(exported_program))
82-
self.add_pass(DecomposeDivPass())
83-
self.add_pass(KeepDimsFalseToSqueezePass())
8481
self.add_pass(ConvertSplitToSlicePass())
85-
self.add_pass(Conv1dUnsqueezePass(exported_program))
86-
self.add_pass(DecomposeSoftmaxesPass())
87-
self.add_pass(DecomposeLinearPass())
82+
# TODO MLETORCH-558
83+
self.add_pass(AnnotateDecomposedMatmulPass())
8884
self.add_pass(QuantizeFullArgument())
8985
self.add_pass(
9086
FoldAndAnnotateQParamsPass(
@@ -93,11 +89,49 @@ def transform_to_backend_pipeline(
9389
exir_ops.edge.aten.maximum.default,
9490
exir_ops.edge.aten.add.Tensor,
9591
exir_ops.edge.aten.avg_pool2d.default,
92+
exir_ops.edge.aten.bmm.default,
93+
exir_ops.edge.aten.cat.default,
9694
exir_ops.edge.aten.convolution.default,
95+
exir_ops.edge.aten.clone.default,
96+
exir_ops.edge.aten.exp.default,
97+
exir_ops.edge.aten.expand_copy.default,
9798
exir_ops.edge.aten.full.default,
99+
exir_ops.edge.aten.hardtanh.default,
100+
exir_ops.edge.aten.log.default,
101+
exir_ops.edge.aten.max_pool2d.default,
102+
exir_ops.edge.aten.mm.default,
103+
exir_ops.edge.aten.mul.Tensor,
104+
exir_ops.edge.aten.permute_copy.default,
105+
exir_ops.edge.aten.reciprocal.default,
106+
exir_ops.edge.aten.relu.default,
107+
exir_ops.edge.aten.repeat.default,
108+
exir_ops.edge.aten.rsqrt.default,
109+
exir_ops.edge.aten.select_copy.int,
110+
exir_ops.edge.aten.sigmoid.default,
111+
exir_ops.edge.aten.slice_copy.Tensor,
112+
exir_ops.edge.aten.squeeze_copy.dims,
113+
exir_ops.edge.aten.sub.Tensor,
114+
exir_ops.edge.aten.sum.dim_IntList,
115+
exir_ops.edge.aten.tanh.default,
116+
exir_ops.edge.aten.unsqueeze_copy.default,
117+
exir_ops.edge.aten.upsample_nearest2d.vec,
118+
exir_ops.edge.aten.view_copy.default,
98119
]
99120
)
100121
)
122+
self.add_pass(RetraceFoldedDtypesPass())
123+
self.add_pass(InsertTableOpsPass(exported_program))
124+
self.add_pass(ConvertExpandCopyToRepeatPass())
125+
self.add_pass(UnsqueezeBeforeRepeatPass())
126+
self.add_pass(CastInt64ToInt32Pass(exported_program))
127+
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
128+
self.add_pass(SizeAdjustConv2DPass())
129+
self.add_pass(RemoveClonePass())
130+
self.add_pass(MatchArgRanksPass(exported_program))
131+
self.add_pass(DecomposeDivPass())
132+
self.add_pass(KeepDimsFalseToSqueezePass())
133+
self.add_pass(Conv1dUnsqueezePass(exported_program))
134+
self.add_pass(DecomposeSoftmaxesPass())
101135
for spec in compile_spec:
102136
if spec.key == "permute_memory_format":
103137
memory_format = spec.value.decode()

backends/arm/_passes/conv1d_unsqueeze_pass.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,8 @@
1212
from executorch.backends.arm._passes.arm_pass_utils import (
1313
create_node,
1414
get_param_tensor,
15-
insert_q_dq_pair,
1615
is_param_node,
1716
)
18-
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
1917
from executorch.exir import ExportedProgram
2018
from executorch.exir.dialects._ops import ops as exir_ops
2119
from executorch.exir.pass_base import ExportPass, PassResult
@@ -27,10 +25,8 @@ class Conv1dUnsqueezePass(ExportPass):
2725
supports 2d and 3d convolution. This is done by modifying the graph to do the
2826
following:
2927
1) unsqueeze the convolution's input from 3d to 4d
30-
2) if the input to unsqueeze is quantized, insert q/dq-pair after unsqueeze
31-
3) perform a conv2d (with a modified version of the original conv1d args)
32-
4) squeeze the output back down to 3d.
33-
5) if all users of squeeze are quantized, insert q/dq-pair before squeeze
28+
2) perform a conv2d (with a modified version of the original conv1d args)
29+
3) squeeze the output back down to 3d.
3430
"""
3531

3632
def __init__(self, exported_program: ExportedProgram) -> None:
@@ -94,8 +90,6 @@ def call(self, graph_module: torch.fx.GraphModule):
9490
continue
9591

9692
kernel_node = node.args[1]
97-
if kernel_node.target == dq_op:
98-
kernel_node = kernel_node.args[0]
9993

10094
if not is_param_node(self.exported_program, kernel_node):
10195
raise AssertionError(
@@ -131,11 +125,6 @@ def call(self, graph_module: torch.fx.GraphModule):
131125
)
132126
node.replace_input_with(input_node, unsqueeze_before)
133127

134-
# If Quantized we must insert unsqueeze --> q --> dq --> node
135-
if input_node.target == dq_op:
136-
q_params = input_node.args[1:]
137-
insert_q_dq_pair(graph, unsqueeze_before, q_params)
138-
139128
with graph.inserting_after(node):
140129
squeeze_after = create_node(
141130
graph,
@@ -151,13 +140,6 @@ def call(self, graph_module: torch.fx.GraphModule):
151140
for user in original_users:
152141
user.replace_input_with(node, squeeze_after)
153142

154-
# If quantized, insert conv2d --> q --> dq --> squeeze
155-
if all(
156-
original_user.target == q_op for original_user in original_users
157-
):
158-
q_params = original_users[0].args[1:]
159-
insert_q_dq_pair(graph, node, q_params)
160-
161143
graph_module.recompile()
162144
# Since we are overriding "call", we need to call the parent's "call"
163145
# to retrace the graph and regenerate metadata

0 commit comments

Comments
 (0)