Skip to content

Commit 2f14536

Browse files
committed
Update base for Update on "[ET-VK] Adding batch processing in x axis to conv2d dw shader by caching input texel for reuse."
This diff adds batch processing in the x axis to the conv2d dw shader by reusing input texel overlapping between consecutive tiles. The changes include modifying the glsl code for the conv2d dw output tile, adding a new parameter to the yaml file, and modifying the Convolution.cpp file to use the new parameter. Differential Revision: [D67868671](https://our.internmc.facebook.com/intern/diff/D67868671/) [ghstack-poisoned]
2 parents bed2032 + 2e24b4e commit 2f14536

File tree

138 files changed

+2094
-1244
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

138 files changed

+2094
-1244
lines changed

.ci/docker/ci_commit_pins/buck2.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2024-05-15
1+
2024-12-16
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import itertools
8+
9+
import torch
10+
from executorch.backends.arm._passes.arm_pass_utils import create_node
11+
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
from executorch.exir.pass_base import ExportPass, PassResult
14+
from torch.fx import GraphModule
15+
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
16+
17+
18+
class AnnotateDecomposedMatmulPass(ExportPass):
    """
    Re-anchor q/dq-nodes directly around the mm/bmm of a decomposed matmul.

    torch.matmul can be decomposed in many ways, for instance:
    dq -> matmul -> q can become
    dq -> repeat -> view -> bmm -> view -> q which makes quantization folding
    difficult. This pass finds all matmul partitions and, for each of them,
    moves the dq-nodes feeding the partition and the q-node consuming it so
    that they sit directly next to the partition's matmul-op (can be mm or
    bmm).
    """

    def call(self, graph_module: GraphModule) -> PassResult:
        """
        Rewrite every torch.matmul source partition found in ``graph_module``.

        If all inputs of a partition are dq-nodes, each of them is removed and
        an equivalent dq-node (same quantization args) is inserted right
        before the mm/bmm node. If the first user of the partition output is a
        q-node, it is removed and an equivalent q-node is inserted right after
        the mm/bmm node.

        Args:
            graph_module: The graph module to transform in place.

        Returns:
            A PassResult wrapping the retraced, recompiled graph module. The
            ``modified`` flag is always True.

        Raises:
            IndexError: If a partition contains no mm/bmm node, has fewer
                input nodes than the mm/bmm node has inputs, or if its first
                output node has no users.
        """
        graph = graph_module.graph
        partitions_by_source = get_source_partitions(
            graph,
            [
                torch.matmul,
            ],
            None,
        )
        matmul_partitions = list(
            itertools.chain.from_iterable(partitions_by_source.values())
        )
        matmul_targets = {
            exir_ops.edge.aten.mm.default,
            exir_ops.edge.aten.bmm.default,
        }
        for partition in matmul_partitions:
            quantized_input = all(
                input_node.target == dq_op for input_node in partition.input_nodes
            )
            matmul_node = [
                node for node in partition.nodes if node.target in matmul_targets
            ][0]
            if quantized_input:
                # all_input_nodes returns a fresh list, so rewiring the matmul
                # inputs below does not disturb this iteration.
                matmul_args = matmul_node.all_input_nodes
                for i, matmul_input_node in enumerate(matmul_args):
                    # NOTE(review): assumes partition.input_nodes is ordered
                    # like the matmul inputs - confirm against
                    # get_source_partitions.
                    input_node = partition.input_nodes[i]
                    # Read the quantization args while the node is still part
                    # of the graph; erase_node() nulls Node-valued args.
                    input_node_qargs = input_node.args[1:]
                    with graph.inserting_before(matmul_node):
                        # Create new dq-node before matmul
                        dq_node = create_node(
                            graph=graph,
                            op_target=dq_op,
                        )
                        dq_node.args = (matmul_input_node, *input_node_qargs)
                        matmul_node.replace_input_with(matmul_input_node, dq_node)
                    # Remove partition input dq-node. This must happen after
                    # the new dq-node is wired in: when the old dq-node feeds
                    # the matmul directly, it is the new dq-node's input and
                    # gets swapped for the quantized tensor here instead of
                    # leaving a reference to an erased node.
                    input_node.replace_all_uses_with(input_node.all_input_nodes[0])
                    graph.erase_node(input_node)

            partition_output = list(partition.output_nodes[0].users)[0]
            quantized_output = partition_output.target == q_op
            if quantized_output:
                output_node_qargs = partition_output.args[1:]
                with graph.inserting_after(matmul_node):
                    # Create q-node after matmul
                    q_node = create_node(
                        graph=graph,
                        op_target=q_op,
                    )
                    # Redirect the matmul users first; q_node has no args yet,
                    # so it does not end up consuming itself.
                    matmul_node.replace_all_uses_with(q_node)
                    q_node.args = (matmul_node, *output_node_qargs)
                # Remove partition output q-node
                partition_output.replace_all_uses_with(
                    partition_output.all_input_nodes[0]
                )
                graph.erase_node(partition_output)

        # retrace the graph to update the fake tensor types
        graph_module = super().call(graph_module).graph_module

        graph_module.recompile()
        return PassResult(graph_module, True)

backends/arm/_passes/arm_pass_manager.py

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
from executorch.backends.arm._passes.annotate_channels_last_dim_order_pass import (
1212
AnnotateChannelsLastDimOrder,
1313
)
14+
from executorch.backends.arm._passes.annotate_decomposed_matmul import (
15+
AnnotateDecomposedMatmulPass,
16+
)
1417
from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass
1518
from executorch.backends.arm._passes.conv1d_unsqueeze_pass import Conv1dUnsqueezePass
1619
from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
@@ -32,7 +35,9 @@
3235
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
3336
FoldAndAnnotateQParamsPass,
3437
QuantizeFullArgument,
38+
RetraceFoldedDtypesPass,
3539
)
40+
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
3641
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
3742
KeepDimsFalseToSqueezePass,
3843
)
@@ -67,24 +72,15 @@ def transform_to_backend_pipeline(
6772
self, exported_program: ExportedProgram, compile_spec: list[CompileSpec]
6873
):
6974
"""Apply passes before transforming program to backend"""
70-
self.add_pass(CastInt64ToInt32Pass(exported_program))
75+
self.add_pass(DecomposeLinearPass())
7176
self.add_pass(RemoveGetItemPass())
72-
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
73-
self.add_pass(SizeAdjustConv2DPass())
74-
self.add_pass(RemoveClonePass())
75-
self.add_pass(ConvertExpandCopyToRepeatPass())
7677
self.add_pass(DecomposeLayerNormPass())
77-
self.add_pass(UnsqueezeBeforeRepeatPass())
7878
self.add_pass(DecomposeVarPass())
7979
self.add_pass(ConvertMeanDimToAveragePool())
8080
self.add_pass(DecomposeMeanDimPass())
81-
self.add_pass(MatchArgRanksPass(exported_program))
82-
self.add_pass(DecomposeDivPass())
83-
self.add_pass(KeepDimsFalseToSqueezePass())
8481
self.add_pass(ConvertSplitToSlicePass())
85-
self.add_pass(Conv1dUnsqueezePass(exported_program))
86-
self.add_pass(DecomposeSoftmaxesPass())
87-
self.add_pass(DecomposeLinearPass())
82+
# TODO MLETORCH-558
83+
self.add_pass(AnnotateDecomposedMatmulPass())
8884
self.add_pass(QuantizeFullArgument())
8985
self.add_pass(
9086
FoldAndAnnotateQParamsPass(
@@ -93,11 +89,49 @@ def transform_to_backend_pipeline(
9389
exir_ops.edge.aten.maximum.default,
9490
exir_ops.edge.aten.add.Tensor,
9591
exir_ops.edge.aten.avg_pool2d.default,
92+
exir_ops.edge.aten.bmm.default,
93+
exir_ops.edge.aten.cat.default,
9694
exir_ops.edge.aten.convolution.default,
95+
exir_ops.edge.aten.clone.default,
96+
exir_ops.edge.aten.exp.default,
97+
exir_ops.edge.aten.expand_copy.default,
9798
exir_ops.edge.aten.full.default,
99+
exir_ops.edge.aten.hardtanh.default,
100+
exir_ops.edge.aten.log.default,
101+
exir_ops.edge.aten.max_pool2d.default,
102+
exir_ops.edge.aten.mm.default,
103+
exir_ops.edge.aten.mul.Tensor,
104+
exir_ops.edge.aten.permute_copy.default,
105+
exir_ops.edge.aten.reciprocal.default,
106+
exir_ops.edge.aten.relu.default,
107+
exir_ops.edge.aten.repeat.default,
108+
exir_ops.edge.aten.rsqrt.default,
109+
exir_ops.edge.aten.select_copy.int,
110+
exir_ops.edge.aten.sigmoid.default,
111+
exir_ops.edge.aten.slice_copy.Tensor,
112+
exir_ops.edge.aten.squeeze_copy.dims,
113+
exir_ops.edge.aten.sub.Tensor,
114+
exir_ops.edge.aten.sum.dim_IntList,
115+
exir_ops.edge.aten.tanh.default,
116+
exir_ops.edge.aten.unsqueeze_copy.default,
117+
exir_ops.edge.aten.upsample_nearest2d.vec,
118+
exir_ops.edge.aten.view_copy.default,
98119
]
99120
)
100121
)
122+
self.add_pass(RetraceFoldedDtypesPass())
123+
self.add_pass(InsertTableOpsPass(exported_program))
124+
self.add_pass(ConvertExpandCopyToRepeatPass())
125+
self.add_pass(UnsqueezeBeforeRepeatPass())
126+
self.add_pass(CastInt64ToInt32Pass(exported_program))
127+
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
128+
self.add_pass(SizeAdjustConv2DPass())
129+
self.add_pass(RemoveClonePass())
130+
self.add_pass(MatchArgRanksPass(exported_program))
131+
self.add_pass(DecomposeDivPass())
132+
self.add_pass(KeepDimsFalseToSqueezePass())
133+
self.add_pass(Conv1dUnsqueezePass(exported_program))
134+
self.add_pass(DecomposeSoftmaxesPass())
101135
for spec in compile_spec:
102136
if spec.key == "permute_memory_format":
103137
memory_format = spec.value.decode()

backends/arm/_passes/conv1d_unsqueeze_pass.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,8 @@
1212
from executorch.backends.arm._passes.arm_pass_utils import (
1313
create_node,
1414
get_param_tensor,
15-
insert_q_dq_pair,
1615
is_param_node,
1716
)
18-
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
1917
from executorch.exir import ExportedProgram
2018
from executorch.exir.dialects._ops import ops as exir_ops
2119
from executorch.exir.pass_base import ExportPass, PassResult
@@ -27,10 +25,8 @@ class Conv1dUnsqueezePass(ExportPass):
2725
supports 2d and 3d convolution. This is done by modifying the graph to do the
2826
following:
2927
1) unsqueeze the convolution's input from 3d to 4d
30-
2) if the input to unsqueeze is quantized, insert q/dq-pair after unsqueeze
31-
3) perform a conv2d (with a modified version of the original conv1d args)
32-
4) squeeze the output back down to 3d.
33-
5) if all users of squeeze are quantized, insert q/dq-pair before squeeze
28+
2) perform a conv2d (with a modified version of the original conv1d args)
29+
3) squeeze the output back down to 3d.
3430
"""
3531

3632
def __init__(self, exported_program: ExportedProgram) -> None:
@@ -94,8 +90,6 @@ def call(self, graph_module: torch.fx.GraphModule):
9490
continue
9591

9692
kernel_node = node.args[1]
97-
if kernel_node.target == dq_op:
98-
kernel_node = kernel_node.args[0]
9993

10094
if not is_param_node(self.exported_program, kernel_node):
10195
raise AssertionError(
@@ -131,11 +125,6 @@ def call(self, graph_module: torch.fx.GraphModule):
131125
)
132126
node.replace_input_with(input_node, unsqueeze_before)
133127

134-
# If Quantized we must insert unsqueeze --> q --> dq --> node
135-
if input_node.target == dq_op:
136-
q_params = input_node.args[1:]
137-
insert_q_dq_pair(graph, unsqueeze_before, q_params)
138-
139128
with graph.inserting_after(node):
140129
squeeze_after = create_node(
141130
graph,
@@ -151,13 +140,6 @@ def call(self, graph_module: torch.fx.GraphModule):
151140
for user in original_users:
152141
user.replace_input_with(node, squeeze_after)
153142

154-
# If quantized, insert conv2d --> q --> dq --> squeeze
155-
if all(
156-
original_user.target == q_op for original_user in original_users
157-
):
158-
q_params = original_users[0].args[1:]
159-
insert_q_dq_pair(graph, node, q_params)
160-
161143
graph_module.recompile()
162144
# Since we are overriding "call", we need to call the parent's "call"
163145
# to retrace the graph and regenerate metadata

0 commit comments

Comments
 (0)