Skip to content

Commit 98cac1f

Browse files
committed
Arm backend: Add support for grouped convolution
Grouped convolution is lowered as separate convolutions on different slices of the input and weights in a new DecomposeGroupedConv pass. Tested in two new tests in test_conv2d. Fuse constant ops pass is additionally updated to make sure all removed placeholders are deleted. Signed-off-by: Adrian Lundell <[email protected]> Change-Id: I280ba9342bb92b826152e49c9570ed8715bc457f
1 parent 994752e commit 98cac1f

File tree

5 files changed

+173
-10
lines changed

5 files changed

+173
-10
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .decompose_div_pass import DecomposeDivPass # noqa
2525
from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa
2626
from .decompose_gelu_pass import DecomposeGeluPass # noqa
27+
from .decompose_grouped_conv import DecomposeGroupedConv # noqa
2728
from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa
2829
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
2930
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
DecomposeDivPass,
2828
DecomposeEmbeddingPass,
2929
DecomposeGeluPass,
30+
DecomposeGroupedConv,
3031
DecomposeGroupNormPass,
3132
DecomposeLayerNormPass,
3233
DecomposeLeakyReLUPass,
@@ -117,6 +118,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
117118
self.add_pass(DecomposeLinearPass())
118119
self.add_pass(ComputeConstantOpsAOT(exported_program))
119120

121+
self.add_pass(DecomposeGroupedConv())
120122
self.add_pass(RemoveClonePass())
121123
self.add_pass(SizeAdjustConv2DPass())
122124
self.add_pass(ConvertExpandCopyToRepeatPass())
@@ -174,6 +176,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
174176
self.add_pass(MatchArgRanksPass(exported_program))
175177
self.add_pass(ComputeConstantOpsAOT(exported_program))
176178

179+
self.add_pass(DecomposeGroupedConv())
177180
self.add_pass(RemoveClonePass())
178181
self.add_pass(SizeAdjustConv2DPass())
179182
self.add_pass(ConvertExpandCopyToRepeatPass())
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from copy import copy
7+
8+
import torch
9+
from executorch.exir.dialects._ops import ops as exir_ops
10+
from executorch.exir.pass_base import ExportPass
11+
12+
13+
class DecomposeGroupedConv(ExportPass):
14+
"""
15+
Splits a grouped convolution which is not supported by TOSA into multiple
16+
convolutions using slice->conv->cat.
17+
18+
Before pass:
19+
x = conv(input, weight, bias, groups = 2)
20+
21+
After pass:
22+
input1 = slice(input)
23+
weight1 = slice(weight)
24+
bias1 = slice(bias)
25+
x1 = conv(input1, weight1, bias1)
26+
27+
input2 = slice(input)
28+
weight2 = slice(weight)
29+
bias2 = slice(bias)
30+
x2 = conv(input2, weight2, bias2)
31+
32+
x = cat(x1, x2)
33+
"""
34+
35+
@staticmethod
36+
def _get_decomposition(op):
37+
match op:
38+
case exir_ops.edge.aten.convolution.default:
39+
return (
40+
exir_ops.edge.aten.slice_copy.Tensor,
41+
exir_ops.edge.aten.convolution.default,
42+
exir_ops.edge.aten.cat.default,
43+
)
44+
case torch.ops.aten.conv2d.default:
45+
return (
46+
torch.ops.aten.slice_copy.Tensor,
47+
torch.ops.aten.conv2d.default,
48+
torch.ops.aten.cat.default,
49+
)
50+
case _:
51+
raise RuntimeError("Unvalid op for grouped conv decomposition.")
52+
53+
def call_operator(self, op, args, kwargs, meta):
54+
if op == exir_ops.edge.aten.convolution.default:
55+
groups = args[8]
56+
transposed = args[6]
57+
elif op == torch.ops.aten.conv2d.default:
58+
groups = args[6]
59+
transposed = False
60+
else:
61+
return super().call_operator(op, args, kwargs, meta)
62+
63+
if groups == 1 or transposed:
64+
return super().call_operator(op, args, kwargs, meta)
65+
66+
input_node = args[0]
67+
if input_node.data.shape[1] == groups:
68+
# This is a depthwise convolution which is handled elsewhere
69+
return super().call_operator(op, args, kwargs, meta)
70+
71+
weight_node = args[1]
72+
bias_node = args[2]
73+
74+
input_slice_size = weight_node.data.shape[1]
75+
output_slice_size = weight_node.data.shape[0] // groups
76+
77+
no_q_dq_meta = copy(meta)
78+
no_q_dq_meta.data = {}
79+
no_q_dq_meta.data = {}
80+
81+
slice_op, conv_op, cat_op = DecomposeGroupedConv._get_decomposition(op)
82+
83+
input_slices = []
84+
for i in range(groups):
85+
start_index = i * input_slice_size
86+
stop_index = (i + 1) * input_slice_size
87+
slice_args = (input_node, 1, start_index, stop_index)
88+
89+
input_slices.append(
90+
super().call_operator(slice_op, slice_args, kwargs, no_q_dq_meta)
91+
)
92+
93+
filter_slices = []
94+
for i in range(groups):
95+
start_index = i * output_slice_size
96+
stop_index = (i + 1) * output_slice_size
97+
slice_args = (weight_node, 0, start_index, stop_index)
98+
99+
filter_slices.append(
100+
super().call_operator(slice_op, slice_args, kwargs, no_q_dq_meta)
101+
)
102+
103+
bias_slices = []
104+
for i in range(groups):
105+
if bias_node is None:
106+
bias_slices.append(None)
107+
else:
108+
109+
start_index = i * output_slice_size
110+
stop_index = (i + 1) * output_slice_size
111+
slice_args = (bias_node, 0, start_index, stop_index)
112+
113+
bias_slices.append(
114+
super().call_operator(slice_op, slice_args, kwargs, no_q_dq_meta)
115+
)
116+
117+
output_slices = []
118+
for input_slice, filter_slice, bias_slice in zip(
119+
input_slices, filter_slices, bias_slices
120+
):
121+
122+
if op == exir_ops.edge.aten.convolution.default:
123+
conv_args = (input_slice, filter_slice, bias_slice, *args[3:8], 1)
124+
elif op == torch.ops.aten.conv2d.default:
125+
conv_args = (input_slice, filter_slice, bias_slice, *args[3:6], 1)
126+
else:
127+
raise RuntimeError("Unvalid op for grouped conv decomposition.")
128+
129+
output_slices.append(
130+
super().call_operator(conv_op, conv_args, kwargs, meta)
131+
)
132+
133+
cat_args = (output_slices, 1)
134+
return super().call_operator(cat_op, cat_args, kwargs, no_q_dq_meta)

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def _fuse_nodes(self, node) -> bool:
9898

9999
def call(self, graph_module):
100100
modified = False
101-
input_nodes_to_delete = []
101+
input_nodes_to_maybe_delete = set()
102102
for node in graph_module.graph.nodes:
103103
if node.op != "call_function":
104104
continue
@@ -128,22 +128,17 @@ def call(self, graph_module):
128128
)
129129
modified |= did_fuse
130130
graph_module.recompile() # Recompile needed to catch chains of constant ops
131-
input_nodes_to_delete.extend(
132-
[
133-
input_node
134-
for input_node in input_nodes
135-
if len(input_node.users) == 1
136-
]
137-
)
131+
input_nodes_to_maybe_delete.update(input_nodes)
138132
except Exception as e:
139133
logger.warning(
140134
f"\nFailed to fuse constant op {node.name} due to exception:\n{str(e)}"
141135
)
142136

143137
if modified:
144138
graph_module.graph.eliminate_dead_code()
145-
for input_node in input_nodes_to_delete:
146-
delete_constant_placeholder(self.exported_program, input_node)
139+
for input_node in input_nodes_to_maybe_delete:
140+
if len(input_node.users) == 0:
141+
delete_constant_placeholder(self.exported_program, input_node)
147142

148143
graph_module = super().call(graph_module).graph_module
149144

backends/arm/test/ops/test_conv2d.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,34 @@ def forward(self, x):
327327
batches=1,
328328
)
329329

330+
conv2d_groups = Conv2d(
331+
in_channels=12,
332+
out_channels=9,
333+
kernel_size=(3, 3),
334+
stride=1,
335+
padding=0,
336+
dilation=1,
337+
width=7,
338+
height=7,
339+
batches=1,
340+
groups=3,
341+
bias=False,
342+
)
343+
344+
conv2d_groups_bias = Conv2d(
345+
in_channels=15,
346+
out_channels=5,
347+
kernel_size=(3, 3),
348+
stride=1,
349+
padding=0,
350+
dilation=1,
351+
width=7,
352+
height=7,
353+
batches=1,
354+
groups=5,
355+
bias=True,
356+
)
357+
330358
# Shenanigan to get a nicer output when test fails. With unittest it looks like:
331359
# FAIL: test_convolution_2d_tosa_BI_2_3x3_1x3x12x12_st2_pd1
332360
test_modules = {
@@ -348,6 +376,8 @@ def forward(self, x):
348376
"3x3_1x3x224x224_st2_pd1": lambda: conv2d_3x3_1x3x224x224_st2_pd1,
349377
"two_conv2d_nobias": lambda: two_conv2d_nobias,
350378
"two_conv2d": lambda: two_conv2d,
379+
"groups": lambda: conv2d_groups,
380+
"groups_bias": lambda: conv2d_groups_bias,
351381
}
352382

353383
fvp_xfails = {

0 commit comments

Comments
 (0)