
Commit 601a487

Update on "[4/N] Add backend options map"
This is to manage the backend <-> BackendOptions map. Users create the backend options map, and the ET runtime reads each backend name and dispatches that backend's list of options to it.

exported-using-ghexport

Differential Revision: [D76149466](https://our.internmc.facebook.com/intern/diff/D76149466/)

[ghstack-poisoned]
2 parents 390aed6 + 667b39a commit 601a487
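For orientation, a minimal sketch of the dispatch flow the commit message describes. It is illustrative only: the real map lives in the C++ runtime, and names such as set_option are hypothetical, not the actual ExecuTorch API.

# Hypothetical sketch of the backend <-> BackendOptions dispatch described
# above; the option keys and the set_option method are illustrative.
backend_options_map = {
    "XnnpackBackend": [("num_threads", 4)],
    "QnnBackend": [("power_mode", "burst")],
}

def dispatch_backend_options(registered_backends, options_map):
    # The runtime reads each backend name from the map and hands that
    # backend its own list of options.
    for backend_name, options in options_map.items():
        backend = registered_backends.get(backend_name)
        if backend is not None:
            for key, value in options:
                backend.set_option(key, value)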

File tree

91 files changed: +3137 additions, −1157 deletions

.ci/scripts/wheel/pre_build_script.sh

Lines changed: 1 addition & 1 deletion
@@ -14,4 +14,4 @@ set -euxo pipefail
 # which does install them. Though we'd need to disable build isolation to be
 # able to see the installed torch package.

-"${GITHUB_WORKSPACE}/${REPOSITORY}/install_requirements.sh"
+"${GITHUB_WORKSPACE}/${REPOSITORY}/install_requirements.sh" --example

CODEOWNERS

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@
 /extension/flat_tensor @lucylq
 /extension/gguf_util @larryliu0820
 /extension/kernel_util @kimishpatel @manuelcandales @swolchok
-/extension/llm @jackzhxng @larryliu0820 @swolchok
+/extension/llm @jackzhxng @larryliu0820 @swolchok @mergennachin
 /extension/memory_allocator @JacobSzwejbka @swolchok
 /extension/module @shoumikhin
 /extension/parallel @kimishpatel @swolchok

backends/arm/_passes/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -20,10 +20,12 @@
 from .convert_split_to_slice import ConvertSplitToSlicePass  # noqa
 from .convert_squeezes_to_view import ConvertSqueezesToViewPass  # noqa
 from .convert_to_clamp import ConvertToClampPass  # noqa
+from .decompose_avg_pool2d import DecomposeAvgPool2d  # noqa
 from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass  # noqa
 from .decompose_div_pass import DecomposeDivPass  # noqa
 from .decompose_embedding_pass import DecomposeEmbeddingPass  # noqa
 from .decompose_gelu_pass import DecomposeGeluPass  # noqa
+from .decompose_grouped_conv import DecomposeGroupedConv  # noqa
 from .decompose_groupnorm_pass import DecomposeGroupNormPass  # noqa
 from .decompose_layernorm_pass import DecomposeLayerNormPass  # noqa
 from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass  # noqa
@@ -32,6 +34,7 @@
 from .decompose_maxpool2d_with_dilation import DecomposeMaxPool2DPass  # noqa
 from .decompose_meandim_pass import DecomposeMeanDimPass  # noqa
 from .decompose_ne_pass import DecomposeNotEqualPass  # noqa
+from .decompose_round_pass import DecomposeRoundPass  # noqa
 from .decompose_select import DecomposeSelectPass  # noqa
 from .decompose_silu_pass import DecomposeSiluPass  # noqa
 from .decompose_softmax_pass import DecomposeSoftmaxPass  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 10 additions & 1 deletion
@@ -23,10 +23,12 @@
     ConvertSplitToSlicePass,
     ConvertSqueezesToViewPass,
     ConvertToClampPass,
+    DecomposeAvgPool2d,
     DecomposeCosineSimilarityPass,
     DecomposeDivPass,
     DecomposeEmbeddingPass,
     DecomposeGeluPass,
+    DecomposeGroupedConv,
     DecomposeGroupNormPass,
     DecomposeLayerNormPass,
     DecomposeLeakyReLUPass,
@@ -35,6 +37,7 @@
     DecomposeMaxPool2DPass,
     DecomposeMeanDimPass,
     DecomposeNotEqualPass,
+    DecomposeRoundPass,
     DecomposeSelectPass,
     DecomposeSiluPass,
     DecomposeSoftmaxPass,
@@ -63,7 +66,6 @@
     UnsqueezeBeforeRepeatPass,
     UnsqueezeScalarPlaceholdersPass,
 )
-
 from executorch.backends.arm.tosa_specification import (
     TosaLoweringContext,
     TosaSpecification,
@@ -115,8 +117,10 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         if self.tosa_spec.is_U55_subset:
             self.add_pass(BroadcastArgsPass())
         self.add_pass(DecomposeLinearPass())
+        self.add_pass(DecomposeAvgPool2d())
         self.add_pass(ComputeConstantOpsAOT(exported_program))

+        self.add_pass(DecomposeGroupedConv())
         self.add_pass(RemoveClonePass())
         self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
@@ -139,6 +143,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         return self._transform(exported_program.graph_module)

     def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
+        self.add_pass(DecomposeRoundPass())
         self.add_pass(DecomposeSqrtPass())
         self.add_pass(ConvertIntPowToMuls())
         self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
@@ -172,8 +177,10 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(RetraceFoldedDtypesPass())
         self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
         self.add_pass(MatchArgRanksPass(exported_program))
+        self.add_pass(DecomposeAvgPool2d())
         self.add_pass(ComputeConstantOpsAOT(exported_program))

+        self.add_pass(DecomposeGroupedConv())
         self.add_pass(RemoveClonePass())
         self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
@@ -219,6 +226,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(InsertCastForOpsWithInt64InputPass())
         self.add_pass(DecomposeEmbeddingPass())
         self.add_pass(DecomposeScaledDotProductAttention())
+        self.add_pass(DecomposeRoundPass())
         self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
         self.add_pass(ScalarsToAttributePass())
         self.add_pass(DecomposeGroupNormPass())
@@ -232,6 +240,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeLinearVectorNormPass())
         self.add_pass(DecomposeSqrtPass())
         self.add_pass(DecomposeSiluPass())
+        self.add_pass(DecomposeAvgPool2d())

         if self.tosa_spec.is_U55_subset:
             # Numerically stable softmax uses amax which is not supported on Ethos-U55
backends/arm/_passes/decompose_avg_pool2d.py

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+from executorch.backends.arm.operators.operator_validation_utils import (
+    adjust_pooling_pad_if_needed,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+
+edge_avgpool_ops = (exir_ops.edge.aten.avg_pool2d.default,)
+aten_avgpool_ops = (torch.ops.aten.avg_pool2d.default,)
+
+
+def get_decomposition(op) -> tuple:
+    if op in edge_avgpool_ops:
+        return (
+            exir_ops.edge.aten.full.default,
+            exir_ops.edge.aten.cat.default,
+            exir_ops.edge.aten.avg_pool2d.default,
+            exir_ops.edge.aten.mul.Tensor,
+        )
+    if op in aten_avgpool_ops:
+        return (
+            torch.ops.aten.full.default,
+            torch.ops.aten.cat.default,
+            torch.ops.aten.avg_pool2d.default,
+            torch.ops.aten.mul.Tensor,
+        )
+    raise RuntimeError(f"Can't get avg_pool2d decomposition for op {op}")
+
+
+class DecomposeAvgPool2d(ExportPass):
+    """Decomposes avg_pool2d variants that TOSA does not support directly:
+    padding is made explicit (full + cat), the pooling then runs without
+    padding, and divisor_override is rewritten as a multiplication by
+    kernel_size / divisor_override after the pooling."""
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in (edge_avgpool_ops + aten_avgpool_ops):
+            return super().call_operator(op, args, kwargs, meta)
+
+        full_op, cat_op, avgpool_op, mul_op = get_decomposition(op)
+
+        x = args[0]
+        kernel_h, kernel_w = args[1]
+        kernel_size = kernel_h * kernel_w
+        stride_h, stride_w = args[2]
+        pad_h, pad_w = new_pad_h, new_pad_w = args[3] if len(args) > 3 else (0, 0)
+        ceil_mode = args[4] if len(args) > 4 else False
+        count_include_pad = args[5] if len(args) > 5 else True
+        divisor_override = args[6] if len(args) > 6 else None
+
+        n, c, h, w = x.data.shape
+        post_pad_w, post_pad_h = (0, 0)
+
+        # count_include_pad == False means that we use a different divisor for
+        # edge elements. When divisor_override is set, this is overridden
+        # anyway. It is easier to replace a constant divisor, so set
+        # count_include_pad == True.
+        if divisor_override is not None:
+            count_include_pad = True
+
+        # Add width padding manually if count_include_pad
+        if count_include_pad and pad_w > 0:
+            pre_pad_shape = [n, c, h, pad_w]
+            pre_pad = super().call_operator(full_op, (pre_pad_shape, 0.0), kwargs, meta)
+
+            if ceil_mode and divisor_override is None:
+                post_pad_w = pad_w
+            else:
+                post_pad_w = adjust_pooling_pad_if_needed(
+                    w, kernel_w, stride_w, pad_w, ceil_mode
+                )
+
+            if post_pad_w > 0:
+                post_pad_shape = [n, c, h, post_pad_w]
+                post_pad = super().call_operator(
+                    full_op, (post_pad_shape, 0.0), kwargs, meta
+                )
+                cat_nodes = [pre_pad, x, post_pad]
+            else:
+                cat_nodes = [pre_pad, x]
+
+            x = super().call_operator(cat_op, (cat_nodes, 3), kwargs, meta)
+            new_pad_w = 0
+
+        # Add height padding manually if count_include_pad
+        if count_include_pad and pad_h > 0:
+            pre_pad_shape = [n, c, pad_h, w + pad_w + post_pad_w]
+            pre_pad = super().call_operator(full_op, (pre_pad_shape, 0.0), kwargs, meta)
+
+            if ceil_mode and divisor_override is None:
+                post_pad_h = pad_h
+            else:
+                post_pad_h = adjust_pooling_pad_if_needed(
+                    h, kernel_h, stride_h, pad_h, ceil_mode
+                )
+
+            if post_pad_h > 0:
+                post_pad_shape = [n, c, post_pad_h, w + pad_w + post_pad_w]
+                post_pad = super().call_operator(
+                    full_op, (post_pad_shape, 0.0), kwargs, meta
+                )
+                cat_nodes = [pre_pad, x, post_pad]
+            else:
+                cat_nodes = [pre_pad, x]
+
+            x = super().call_operator(cat_op, (cat_nodes, 2), kwargs, meta)
+            new_pad_h = 0
+
+        avgpool_args = (x, args[1], args[2], [new_pad_h, new_pad_w], ceil_mode, False)
+        x = super().call_operator(avgpool_op, avgpool_args, kwargs, meta)
+
+        # Multiply by kernel_size / divisor_override if a divisor_override is given
+        if divisor_override is not None and divisor_override != kernel_size:
+            override_multiplier = super().call_operator(
+                full_op, ([1, 1, 1, 1], kernel_size / divisor_override), kwargs, meta
+            )
+            x = super().call_operator(mul_op, (x, override_multiplier), kwargs, meta)
+
+        return x
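The rewrite above can be sanity-checked numerically in plain PyTorch. The following is a standalone sketch (not code from the commit): pad explicitly with zeros, pool without padding, then scale by kernel_size / divisor_override.

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
kernel, stride, pad, divisor = (3, 3), (2, 2), (1, 1), 5

# Reference: avg_pool2d with implicit zero padding and an overridden divisor.
expected = F.avg_pool2d(x, kernel, stride, pad, ceil_mode=False,
                        count_include_pad=True, divisor_override=divisor)

# Decomposed form: explicit zero padding (the full + cat steps), pooling with
# no padding, then a multiplication by kernel_size / divisor_override (9 / 5).
x_padded = F.pad(x, (1, 1, 1, 1), value=0.0)
actual = F.avg_pool2d(x_padded, kernel, stride, 0) * (9 / divisor)

torch.testing.assert_close(actual, expected)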
backends/arm/_passes/decompose_grouped_conv.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from copy import copy
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+
+
+class DecomposeGroupedConv(ExportPass):
+    """
+    Splits a grouped convolution, which is not supported by TOSA, into
+    multiple convolutions using slice -> conv -> cat.
+
+    Before pass:
+        x = conv(input, weight, bias, groups=2)
+
+    After pass:
+        input1 = slice(input)
+        weight1 = slice(weight)
+        bias1 = slice(bias)
+        x1 = conv(input1, weight1, bias1)
+
+        input2 = slice(input)
+        weight2 = slice(weight)
+        bias2 = slice(bias)
+        x2 = conv(input2, weight2, bias2)
+
+        x = cat(x1, x2)
+    """
+
+    @staticmethod
+    def _get_decomposition(op):
+        match op:
+            case exir_ops.edge.aten.convolution.default:
+                return (
+                    exir_ops.edge.aten.slice_copy.Tensor,
+                    exir_ops.edge.aten.convolution.default,
+                    exir_ops.edge.aten.cat.default,
+                )
+            case torch.ops.aten.conv2d.default:
+                return (
+                    torch.ops.aten.slice_copy.Tensor,
+                    torch.ops.aten.conv2d.default,
+                    torch.ops.aten.cat.default,
+                )
+            case _:
+                raise RuntimeError("Invalid op for grouped conv decomposition.")
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op == exir_ops.edge.aten.convolution.default:
+            groups = args[8]
+            transposed = args[6]
+        elif op == torch.ops.aten.conv2d.default:
+            groups = args[6]
+            transposed = False
+        else:
+            return super().call_operator(op, args, kwargs, meta)
+
+        if groups == 1 or transposed:
+            return super().call_operator(op, args, kwargs, meta)
+
+        input_node = args[0]
+        if input_node.data.shape[1] == groups:
+            # This is a depthwise convolution which is handled elsewhere
+            return super().call_operator(op, args, kwargs, meta)
+
+        weight_node = args[1]
+        bias_node = args[2]
+
+        input_slice_size = weight_node.data.shape[1]
+        output_slice_size = weight_node.data.shape[0] // groups
+
+        no_q_dq_meta = copy(meta)
+        no_q_dq_meta.data = {}
+
+        slice_op, conv_op, cat_op = DecomposeGroupedConv._get_decomposition(op)
+
+        input_slices = []
+        for i in range(groups):
+            start_index = i * input_slice_size
+            stop_index = (i + 1) * input_slice_size
+            slice_args = (input_node, 1, start_index, stop_index)
+
+            input_slices.append(
+                super().call_operator(slice_op, slice_args, kwargs, no_q_dq_meta)
+            )
+
+        filter_slices = []
+        for i in range(groups):
+            start_index = i * output_slice_size
+            stop_index = (i + 1) * output_slice_size
+            slice_args = (weight_node, 0, start_index, stop_index)
+
+            filter_slices.append(
+                super().call_operator(slice_op, slice_args, kwargs, no_q_dq_meta)
+            )
+
+        bias_slices = []
+        for i in range(groups):
+            if bias_node is None:
+                bias_slices.append(None)
+            else:
+                start_index = i * output_slice_size
+                stop_index = (i + 1) * output_slice_size
+                slice_args = (bias_node, 0, start_index, stop_index)
+
+                bias_slices.append(
+                    super().call_operator(slice_op, slice_args, kwargs, no_q_dq_meta)
+                )
+
+        output_slices = []
+        for input_slice, filter_slice, bias_slice in zip(
+            input_slices, filter_slices, bias_slices
+        ):
+            if op == exir_ops.edge.aten.convolution.default:
+                conv_args = (input_slice, filter_slice, bias_slice, *args[3:8], 1)
+            elif op == torch.ops.aten.conv2d.default:
+                conv_args = (input_slice, filter_slice, bias_slice, *args[3:6], 1)
+            else:
+                raise RuntimeError("Invalid op for grouped conv decomposition.")
+
+            output_slices.append(
+                super().call_operator(conv_op, conv_args, kwargs, meta)
+            )
+
+        cat_args = (output_slices, 1)
+        return super().call_operator(cat_op, cat_args, kwargs, no_q_dq_meta)

backends/arm/_passes/decompose_maxpool2d_with_dilation.py

Lines changed: 4 additions & 6 deletions
@@ -36,6 +36,7 @@ def call_operator(self, op, args, kwargs, meta):
         stride = args[2]
         padding = args[3] if len(args) >= 4 else 0
         dilation = args[4] if len(args) >= 5 else 1
+        ceil_mode = args[5] if len(args) == 6 else False

         # Normalize attributes
         pad_h, pad_w = (padding, padding) if isinstance(padding, int) else padding
@@ -45,12 +46,9 @@ def call_operator(self, op, args, kwargs, meta):
         )
         s_h, s_w = (stride, stride) if isinstance(stride, int) else stride

-        # If no dilation: call EXIR edge op with only supported args (x, kernel, stride[, padding])
+        # If no dilation: call EXIR edge op
         if d_h == 1 and d_w == 1:
-            minimal_args = [x, kernel_size, stride]
-            # only include padding if non-zero
-            if (pad_h, pad_w) != (0, 0):
-                minimal_args.append((pad_h, pad_w))
+            minimal_args = [x, kernel_size, stride, padding, dilation, ceil_mode]
             return super().call_operator(op, tuple(minimal_args), {}, meta)

         # Compute padded and packed dimensions for dilation > 1
@@ -102,7 +100,7 @@ def call_operator(self, op, args, kwargs, meta):
             if is_with_indices
             else exir_ops.edge.aten.max_pool2d.default
         )
-        pool_args = (x2, (k_h, k_w), (s_h, s_w), (0, 0))
+        pool_args = (x2, (k_h, k_w), (s_h, s_w), (0, 0), 1, ceil_mode)
        pool_out = super().call_operator(
             pool_edge_op,
             pool_args,
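The reason ceil_mode now has to be threaded through both call sites is that it changes the pooling output size. A small standalone illustration (not from the commit):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 7, 7)
floor_out = F.max_pool2d(x, kernel_size=2, stride=2)                 # floor((7-2)/2)+1 = 3
ceil_out = F.max_pool2d(x, kernel_size=2, stride=2, ceil_mode=True)  # ceil((7-2)/2)+1 = 4
assert floor_out.shape[-2:] == (3, 3)
assert ceil_out.shape[-2:] == (4, 4)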
