Skip to content

Commit 8ec08f9

Browse files
authored
Arm backend: Add is_node_supported checks for 4 ops (#8209)
Remove unnecessary asserts from op_sigmoid and op_log. Add is_node_supported checks for 4 ops: convolution, maxpool2d, avgpool2d, and sum. The checks mostly target hardware constraints on Ethos-U55, though convolution also checks for some unsupported behavior. Signed-off-by: Erik Lundell <[email protected]>
1 parent 883ff14 commit 8ec08f9

File tree

11 files changed

+389
-23
lines changed

11 files changed

+389
-23
lines changed
Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
77

8-
from . import right_shift_support, to_copy_support, tosa_supported_operators # noqa
8+
from . import ( # noqa
9+
convolution_support,
10+
pool_2d_support,
11+
reduce_sum_support,
12+
right_shift_support,
13+
to_copy_support,
14+
tosa_supported_operators,
15+
)
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import cast
7+
8+
import torch
9+
import torch.fx as fx
10+
from executorch.backends.arm.operator_support.tosa_supported_operators import (
11+
register_tosa_support_check,
12+
SupportedTOSAOperatorCheck,
13+
)
14+
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
15+
from executorch.exir.dialects._ops import ops as exir_ops
16+
17+
18+
@register_tosa_support_check
class ConvolutionSupported(SupportedTOSAOperatorCheck):
    """Support check for ``aten.convolution`` nodes on the TOSA-0.80 profiles.

    Transposed convolutions and non-zero output padding are always rejected.
    When the spec is the Ethos-U55 subset, extra hardware limits are applied.
    """

    targets = [exir_ops.edge.aten.convolution.default]

    tosa_specs = [
        TosaSpecification.create_from_string(spec_name)
        for spec_name in ("TOSA-0.80+BI", "TOSA-0.80+MI")
    ]

    def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
        """Return True if the convolution ``node`` can be lowered for ``tosa_spec``."""
        # Features that are not implemented at all.
        is_transposed = cast(bool, node.args[6])
        out_padding = cast(list[int], node.args[7])
        if is_transposed or any(pad != 0 for pad in out_padding):
            return False

        # Only the U55 subset carries additional hardware constraints.
        if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset:
            return self._is_node_supported_u55(node)
        return True

    def _is_node_supported_u55(self, node: fx.Node):
        """Apply the Ethos-U55 hardware limits (Vela 4.2.0, 25.02 release)."""
        input_shape = cast(torch.Tensor, node.all_input_nodes[0].meta["val"]).shape
        output_shape = node.meta["val"].shape
        weight_shape = cast(fx.Node, node.args[1]).meta["val"].shape
        groups = cast(int, node.args[8])

        in_channels = input_shape[1]
        out_channels = output_shape[1]
        is_depthwise = (in_channels == groups) and (out_channels % in_channels) == 0
        # Depthwise: every non-batch input dim is range checked.
        # Regular convolution: only the input channel count is.
        checked_dims = input_shape[1:] if is_depthwise else (in_channels,)
        for extent in checked_dims:
            if not 1 <= extent <= 65536:
                return False

        # NOTE(review): index 2 is treated as width and index 3 as height;
        # confirm this matches the weight layout at this stage of lowering.
        kernel_w = weight_shape[2]
        kernel_h = weight_shape[3] if len(weight_shape) > 3 else 1
        # The hardware constraint on the sum of absolute weights is not checked.
        kernel_ok = 1 <= kernel_h <= 64 and 1 <= kernel_w * kernel_h <= 4096
        if not kernel_ok:
            return False

        return self._stride_condition(node)

    def _stride_condition(self, node: fx.Node) -> bool:
        """Check a relaxed form of the hardware stride constraint.

        A stride outside [1, 3] is only accepted when there is no padding and
        the matching dilation is 1. The real constraint needs information that
        is not readily available here, so this simplified version may accept
        some ops that are not actually supported.
        """
        strides = cast(list[int], node.args[3])
        padded = any(pad > 0 for pad in cast(list[int], node.args[4]))
        dilations = cast(list[int], node.args[5])
        # A single value is expanded to a pair.
        if len(dilations) == 1:
            dilations = [dilations[0]] * 2
        if len(strides) == 1:
            strides = [strides[0]] * 2

        return all(
            (1 <= stride <= 3) or ((not padded) and (dilation == 1))
            for stride, dilation in zip(strides, dilations)
        )
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import cast
7+
8+
import torch
9+
import torch.fx as fx
10+
from executorch.backends.arm.operator_support.tosa_supported_operators import (
11+
register_tosa_support_check,
12+
SupportedTOSAOperatorCheck,
13+
)
14+
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
15+
from executorch.exir.dialects._ops import ops as exir_ops
16+
17+
18+
def kernel_check(kernel: tuple[int, int]) -> bool:
    """Return True if the pooling kernel is within the checked size limits.

    The product of both kernel entries must lie in [1, 65536] and the second
    entry must lie in [1, 256].
    """
    area = kernel[0] * kernel[1]
    return 1 <= area <= 65536 and 1 <= kernel[1] <= 256
22+
23+
24+
def stride_check(strides: tuple[int, int]) -> bool:
    """Return True if every stride lies in the range [1, 3]."""
    for step in strides:
        if step < 1 or step > 3:
            return False
    return True
26+
27+
28+
def dim_check(shape: torch.Size) -> bool:
    """Return True if ``shape`` satisfies the checked dimension limits.

    The first dimension must be exactly 1 and every dimension must lie in
    [1, 65536].

    Args:
        shape: Shape of the input tensor; any indexable sequence of ints works.

    Returns:
        True when all limits hold, otherwise False.
    """
    # ``shape`` is annotated rather than defaulted: the original
    # ``shape=torch.Size`` made the torch.Size class itself the default value.
    return bool(shape[0] == 1 and all(1 <= dim <= 65536 for dim in shape))
33+
34+
35+
@register_tosa_support_check
class AvgPool2dSupported(SupportedTOSAOperatorCheck):
    """Support check for ``aten.avg_pool2d`` nodes on the TOSA-0.80 profiles."""

    targets = [
        exir_ops.edge.aten.avg_pool2d.default,
    ]

    tosa_specs = [
        TosaSpecification.create_from_string(spec_name)
        for spec_name in ("TOSA-0.80+BI", "TOSA-0.80+MI")
    ]

    def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
        """Return True if the pooling ``node`` can be lowered for ``tosa_spec``.

        Anything other than the U55 subset is accepted unconditionally.
        """
        on_u55 = isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset
        if not on_u55:
            return True

        # U55 case, Vela 4.2.0 (25.02 release)
        input_shape = cast(torch.Tensor, node.all_input_nodes[0].meta["val"]).shape
        kernel_size = cast(tuple[int, int], node.args[1])
        strides = cast(tuple[int, int], node.args[2])

        if len(node.args) > 3:
            # A padding argument is present: kernel entries must be in [1, 8].
            kernel_ok = all(1 <= k <= 8 for k in kernel_size)
        else:
            kernel_ok = kernel_check(kernel_size)
        if not kernel_ok:
            return False

        return dim_check(input_shape) and stride_check(strides)
63+
64+
65+
@register_tosa_support_check
class MaxPool2dSupported(SupportedTOSAOperatorCheck):
    """Support check for ``aten.max_pool2d_with_indices`` on the TOSA-0.80 profiles."""

    targets = [
        exir_ops.edge.aten.max_pool2d_with_indices.default,
    ]

    tosa_specs = [
        TosaSpecification.create_from_string(spec_name)
        for spec_name in ("TOSA-0.80+BI", "TOSA-0.80+MI")
    ]

    def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
        """Return True if the pooling ``node`` can be lowered for ``tosa_spec``.

        Anything other than the U55 subset is accepted unconditionally.
        """
        on_u55 = isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset
        if not on_u55:
            return True

        # U55 case, Vela 4.2.0 (25.02 release)
        input_shape = cast(torch.Tensor, node.all_input_nodes[0].meta["val"]).shape
        kernel_size = cast(tuple[int, int], node.args[1])
        strides = cast(tuple[int, int], node.args[2])

        # Checks run in the order kernel, dims, stride; the first failure wins.
        if not kernel_check(kernel_size):
            return False
        if not dim_check(input_shape):
            return False
        return stride_check(strides)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import cast
7+
8+
import torch.fx as fx
9+
from executorch.backends.arm.operator_support.tosa_supported_operators import (
10+
register_tosa_support_check,
11+
SupportedTOSAOperatorCheck,
12+
)
13+
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
14+
from executorch.exir.dialects._ops import ops as exir_ops
15+
16+
17+
@register_tosa_support_check
class SumSupported(SupportedTOSAOperatorCheck):
    """Support check for ``aten.sum.dim_IntList`` nodes on the TOSA-0.80 profiles."""

    targets = [exir_ops.edge.aten.sum.dim_IntList]

    tosa_specs = [
        TosaSpecification.create_from_string(spec_name)
        for spec_name in ("TOSA-0.80+BI", "TOSA-0.80+MI")
    ]

    def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
        """Return True if the reduction ``node`` can be lowered for ``tosa_spec``.

        Anything other than the U55 subset is accepted unconditionally.
        """
        on_u55 = isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset
        if not on_u55:
            return True

        # U55 case, Vela 4.2.0 (25.02 release)
        shape = node.all_input_nodes[0].meta["val"].shape
        rank = len(shape)
        # Map negative dims onto their non-negative equivalents.
        reduce_dims = [axis % rank for axis in cast(list[int], node.args[1])]

        for axis in reduce_dims:
            if not 1 <= shape[axis] <= 65536:
                return False

            # We can't be certain of which dim is the last in memory yet,
            # so always apply the stricter condition: the product of the dims
            # before the reduced axis and the product of the dims after it
            # must each be within range.
            for side in (shape[:axis], shape[axis + 1 :]):
                volume = 1.0
                for extent in side:
                    volume *= extent
                if not 1 <= volume <= 65536:
                    return False
        return True

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
8282
exir_ops.edge.aten.hardsigmoid.default,
8383
exir_ops.edge.aten.hardtanh.default,
8484
exir_ops.edge.aten.hardswish.default,
85-
exir_ops.edge.aten.convolution.default,
8685
exir_ops.edge.aten.div.Tensor,
8786
exir_ops.edge.aten.eq.Tensor,
8887
exir_ops.edge.aten.exp.default,
@@ -97,8 +96,6 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
9796
exir_ops.edge.aten.mul.Tensor,
9897
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
9998
exir_ops.edge.aten.native_layer_norm.default,
100-
exir_ops.edge.aten.avg_pool2d.default,
101-
exir_ops.edge.aten.max_pool2d_with_indices.default,
10299
exir_ops.edge.aten.sigmoid.default,
103100
exir_ops.edge.aten.mean.dim,
104101
exir_ops.edge.aten.mm.default,
@@ -113,7 +110,6 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
113110
exir_ops.edge.aten._log_softmax.default,
114111
exir_ops.edge.aten.slice_copy.Tensor,
115112
exir_ops.edge.aten.sub.Tensor,
116-
exir_ops.edge.aten.sum.dim_IntList,
117113
exir_ops.edge.aten.tanh.default,
118114
exir_ops.edge.aten.upsample_nearest2d.vec,
119115
exir_ops.edge.aten.var.correction,

backends/arm/operators/op_log.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ def define_node(
3636
output: TosaArg,
3737
) -> None:
3838
assert len(node.all_input_nodes) == 1
39-
assert len(node.users) == 1
4039
assert inputs[0].dtype == output.dtype == ts.DType.FP32
4140

4241
tosa_graph.addOperator(TosaOp.Op().LOG, [inputs[0].name], [output.name])

backends/arm/operators/op_sigmoid.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ def define_node(
3737
) -> None:
3838

3939
assert len(node.all_input_nodes) == 1
40-
assert len(node.users) == 1
4140
assert inputs[0].dtype == output.dtype == ts.DType.FP32
4241

4342
tosa_graph.addOperator(TosaOp.Op().SIGMOID, [inputs[0].name], [output.name])

backends/arm/test/ops/test_avg_pool.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,36 @@ def test_avgpool2d_tosa_u85_BI(
172172
common.get_u85_compile_spec(),
173173
(test_data,),
174174
)
175+
176+
# (module, input) pairs that are expected NOT to be delegated on U55.
# Each entry is (AvgPool2d positional args, input shape).
reject_data_suite = [
    (AvgPool2d(*pool_args), torch.rand(*input_shape))
    for pool_args, input_shape in (
        ((1, 1, 0), (2, 5, 5, 5)),
        (((2, 9), 1, 1), (1, 16, 5, 32)),
        ((1, 4, 0), (1, 10, 10, 10)),
        (((1, 257), 1, 0), (1, 16, 5, 300)),
        (((800, 90), 1, 0), (1, 16, 850, 100)),
    )
]
183+
184+
@parameterized.expand(reject_data_suite)
def test_reject_avgpool2d_u55_BI(
    self,
    module: torch.nn.Module,
    test_data: torch.tensor,
):
    """Check that the avg_pool2d op stays in the graph and nothing is delegated."""
    u55_compile_spec = common.get_u55_compile_spec()
    u55_tosa_spec = TosaSpecification.create_from_compilespecs(u55_compile_spec)
    quantizer = ArmQuantizer(u55_tosa_spec).set_io(
        get_symmetric_quantization_config()
    )

    tester = ArmTester(
        module,
        example_inputs=(test_data,),
        compile_spec=u55_compile_spec,
    )
    tester = tester.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
    tester = tester.export()
    tester = tester.check_count({"torch.ops.aten.avg_pool2d.default": 1})
    tester = tester.check(["torch.ops.quantized_decomposed"])
    tester = tester.to_edge_transform_and_lower()
    # The op must still be present as an edge op, with zero delegate calls.
    tester = tester.check(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"])
    tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 0})

backends/arm/test/ops/test_conv2d.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import torch
1111
from executorch.backends.arm.test import common
12+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
1213
from executorch.backends.arm.test.tester.test_pipeline import (
1314
EthosU55PipelineBI,
1415
EthosU85PipelineBI,
@@ -406,3 +407,57 @@ def test_conv2d_u85_BI_on_fvp(test_module):
406407
test_module, test_module.get_inputs(), aten_op, exir_op, run_on_fvp=True
407408
)
408409
pipeline.run()
410+
411+
412+
# Conv2d modules that are expected NOT to be delegated on U55. All cases use
# a single input channel, a single output channel and a batch size of one.
reject_suite = {
    case_name: Conv2d(in_channels=1, out_channels=1, batches=1, **case_kwargs)
    for case_name, case_kwargs in {
        "large_stride": dict(
            kernel_size=(2, 4),
            stride=(2, 4),
            padding=1,
            width=10,
            height=14,
        ),
        "large_kernel_height": dict(
            kernel_size=(2, 65),
            stride=(1, 1),
            padding=0,
            width=70,
            height=70,
        ),
        "large_kernel": dict(
            kernel_size=(70, 60),
            stride=(1,),
            padding=0,
            width=80,
            height=80,
        ),
    }.items()
}
444+
445+
446+
@common.parametrize("module", reject_suite)
def test_reject_conv2d_u55_BI(
    module: Conv2d,
):
    """Check that the convolution op stays in the graph and nothing is delegated."""
    tester = ArmTester(
        module,
        example_inputs=module.get_inputs(),
        compile_spec=common.get_u55_compile_spec(),
    )
    tester = tester.quantize()
    tester = tester.export()
    tester = tester.check_count({"torch.ops.aten.conv2d.default": 1})
    tester = tester.check(["torch.ops.quantized_decomposed"])
    tester = tester.to_edge_transform_and_lower()
    # The op must still be present as an edge op, with zero delegate calls.
    tester = tester.check(["executorch_exir_dialects_edge__ops_aten_convolution_default"])
    tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 0})

0 commit comments

Comments
 (0)