Skip to content

Commit 9582186

Browse files
committed
Refactor TOSAOperatorSupport to allow for maping against TosaSpecification
Change-Id: Ib7eb502861948c2f8dee995f28cbb7f2baa00afb
1 parent ff28b0d commit 9582186

File tree

5 files changed

+206
-74
lines changed

5 files changed

+206
-74
lines changed

backends/arm/arm_partitioner.py

Lines changed: 4 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
# pyre-unsafe
77

88
import logging
9-
import operator
109
import os
11-
from typing import Callable, cast, final, List, Optional, Tuple
10+
from typing import Callable, final, List, Optional, Tuple
1211

1312
import torch
1413
from executorch.backends.arm.arm_backend import ArmBackend # usort: skip
1514
from executorch.backends.arm._passes.tag_io_quant_pass import TagIOQuantPass
15+
from executorch.backends.arm.operator_support.tosa_supported_operators import (
16+
TOSASupportedOperators,
17+
)
1618
from executorch.backends.arm.tosa_specification import TosaSpecification
1719
from executorch.exir.backend.compile_spec_schema import CompileSpec
1820
from executorch.exir.backend.partitioner import (
@@ -21,13 +23,10 @@
2123
PartitionResult,
2224
)
2325
from executorch.exir.backend.utils import tag_constant_data
24-
from executorch.exir.dialects._ops import ops as exir_ops
2526
from executorch.exir.passes import PassManager
2627
from torch.export.exported_program import ExportedProgram
2728
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
2829

29-
from torch.fx.passes.operator_support import OperatorSupportBase
30-
3130
logger = logging.getLogger(__name__)
3231
logger.setLevel(logging.WARNING)
3332
TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1"
@@ -36,75 +35,6 @@
3635
logger.setLevel(logging.INFO)
3736

3837

39-
class TOSASupportedOperators(OperatorSupportBase):
40-
def __init__(self, tosa_spec: TosaSpecification):
41-
super().__init__()
42-
self.tosa_spec = tosa_spec
43-
44-
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
45-
supported = node.op == "call_function" and node.target in [
46-
exir_ops.edge.aten.add.Tensor,
47-
exir_ops.edge.aten.expand_copy.default,
48-
exir_ops.edge.aten.cat.default,
49-
exir_ops.edge.aten.bmm.default,
50-
exir_ops.edge.aten.permute_copy.default,
51-
exir_ops.edge.aten.hardtanh.default,
52-
exir_ops.edge.aten.convolution.default,
53-
exir_ops.edge.aten.div.Tensor,
54-
exir_ops.edge.aten.exp.default,
55-
exir_ops.edge.aten.log.default,
56-
exir_ops.edge.aten.linear.default,
57-
exir_ops.edge.aten.split_with_sizes_copy.default,
58-
exir_ops.edge.aten.full.default,
59-
exir_ops.edge.aten.mul.Tensor,
60-
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
61-
exir_ops.edge.aten.native_layer_norm.default,
62-
exir_ops.edge.aten.avg_pool2d.default,
63-
exir_ops.edge.aten.max_pool2d_with_indices.default,
64-
exir_ops.edge.aten.sigmoid.default,
65-
exir_ops.edge.aten.mm.default,
66-
exir_ops.edge.aten.repeat.default,
67-
exir_ops.edge.aten.reciprocal.default,
68-
exir_ops.edge.aten.relu.default,
69-
exir_ops.edge.aten.rsqrt.default,
70-
exir_ops.edge.aten._softmax.default,
71-
exir_ops.edge.aten.select_copy.int,
72-
exir_ops.edge.aten._log_softmax.default,
73-
exir_ops.edge.aten.slice_copy.Tensor,
74-
exir_ops.edge.aten.sub.Tensor,
75-
exir_ops.edge.aten.sum.dim_IntList,
76-
exir_ops.edge.aten.tanh.default,
77-
exir_ops.edge.aten.upsample_nearest2d.vec,
78-
exir_ops.edge.aten.view_copy.default,
79-
exir_ops.edge.aten.clone.default,
80-
exir_ops.edge.aten.mean.dim,
81-
exir_ops.edge.aten.var.correction,
82-
exir_ops.edge.aten.unsqueeze_copy.default,
83-
exir_ops.edge.aten.squeeze_copy.dims,
84-
operator.getitem,
85-
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
86-
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
87-
]
88-
89-
supported &= self.is_node_supported_custom(node)
90-
91-
# Override partitioning based on pre partition passes
92-
if "arm_override_partition" in node.meta:
93-
supported = supported & node.meta["arm_override_partition"]
94-
node.meta.pop("arm_override_partition")
95-
96-
return supported
97-
98-
def is_node_supported_custom(self, node: torch.fx.Node) -> bool:
99-
if node.target == exir_ops.edge.aten.mean.dim:
100-
keep_dim = node.args[2] if len(node.args) > 2 else False
101-
return cast(bool, keep_dim)
102-
if node.target == exir_ops.edge.aten.var.correction:
103-
keep_dim = node.kwargs.get("keepdim", False)
104-
return cast(bool, keep_dim)
105-
return True
106-
107-
10838
@final
10939
class ArmPartitioner(Partitioner):
11040
def __init__(self, compile_spec: List[CompileSpec]) -> None:
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from . import mean_dim_support, tosa_supported_operators, var_correction_support # noqa
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from typing import cast
9+
10+
import torch.fx as fx
11+
12+
from executorch.backends.arm.operator_support.tosa_supported_operators import (
13+
register_tosa_support_check,
14+
SupportedTOSAOperatorCheck,
15+
)
16+
from executorch.backends.arm.tosa_specification import TosaSpecification
17+
from executorch.exir.dialects._ops import ops as exir_ops
18+
19+
20+
@register_tosa_support_check
21+
class MeanDimSupported(SupportedTOSAOperatorCheck):
22+
targets = [exir_ops.edge.aten.mean.dim]
23+
24+
tosa_specs = [
25+
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
26+
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
27+
]
28+
29+
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:
30+
assert node.target in self.targets
31+
32+
keep_dim = node.args[2] if len(node.args) > 2 else False
33+
return cast(bool, keep_dim)
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
import operator
9+
10+
import torch.fx as fx
11+
from executorch.backends.arm.tosa_specification import TosaSpecification
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
from torch.fx.passes.operator_support import OperatorSupportBase
14+
15+
16+
class SupportedTOSAOperatorCheck:
17+
"""
18+
Supported OP for TOSA lowering
19+
"""
20+
21+
# Should be populated by subclass implementation
22+
tosa_specs: list[TosaSpecification] = []
23+
targets: list[str] = []
24+
25+
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:
26+
"""
27+
Checks if the fx.Node node is lowerable using the TOSA specification defined by tosa_spec.
28+
To be implemented by subclasses targeting
29+
"""
30+
raise NotImplementedError("NodeVisitor must be extended.")
31+
32+
33+
# container for all SupportedTosaOperatorCheck classes
34+
_tosa_spec_dicts: dict[TosaSpecification, dict[str, SupportedTOSAOperatorCheck]] = {
35+
TosaSpecification.create_from_string("TOSA-0.80.0+BI"): {},
36+
TosaSpecification.create_from_string("TOSA-0.80.0+MI"): {},
37+
}
38+
39+
40+
def register_tosa_support_check(checker):
41+
"""
42+
Decorator to mark a subclass implmentation of SupportedTosaOperatorCheck
43+
to be registered for checking if a torch.fx.Node is lowerable given
44+
a TOSA specification.
45+
"""
46+
for tosa_spec in checker.tosa_specs:
47+
for target in checker.targets:
48+
_tosa_spec_dicts[tosa_spec][target] = checker
49+
return checker
50+
51+
52+
def get_registered_tosa_support_checks(
53+
tosa_spec: TosaSpecification,
54+
) -> dict[str, SupportedTOSAOperatorCheck]:
55+
56+
if tosa_spec not in _tosa_spec_dicts:
57+
raise RuntimeError
58+
59+
tosa_support_checks = {}
60+
for target, tosa_check in _tosa_spec_dicts[tosa_spec].items():
61+
tosa_support_checks[target] = tosa_check()
62+
63+
return tosa_support_checks
64+
65+
66+
class TOSASupportedOperators(OperatorSupportBase):
67+
def __init__(self, tosa_spec: TosaSpecification):
68+
super().__init__()
69+
self.tosa_spec = tosa_spec
70+
71+
def is_node_supported(self, submodules, node: fx.Node) -> bool:
72+
supported = node.op == "call_function" and node.target in [
73+
exir_ops.edge.aten.add.Tensor,
74+
exir_ops.edge.aten.expand_copy.default,
75+
exir_ops.edge.aten.cat.default,
76+
exir_ops.edge.aten.bmm.default,
77+
exir_ops.edge.aten.permute_copy.default,
78+
exir_ops.edge.aten.hardtanh.default,
79+
exir_ops.edge.aten.convolution.default,
80+
exir_ops.edge.aten.div.Tensor,
81+
exir_ops.edge.aten.exp.default,
82+
exir_ops.edge.aten.log.default,
83+
exir_ops.edge.aten.linear.default,
84+
exir_ops.edge.aten.split_with_sizes_copy.default,
85+
exir_ops.edge.aten.full.default,
86+
exir_ops.edge.aten.mul.Tensor,
87+
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
88+
exir_ops.edge.aten.native_layer_norm.default,
89+
exir_ops.edge.aten.avg_pool2d.default,
90+
exir_ops.edge.aten.max_pool2d_with_indices.default,
91+
exir_ops.edge.aten.sigmoid.default,
92+
exir_ops.edge.aten.mm.default,
93+
exir_ops.edge.aten.repeat.default,
94+
exir_ops.edge.aten.reciprocal.default,
95+
exir_ops.edge.aten.relu.default,
96+
exir_ops.edge.aten.rsqrt.default,
97+
exir_ops.edge.aten._softmax.default,
98+
exir_ops.edge.aten.select_copy.int,
99+
exir_ops.edge.aten._log_softmax.default,
100+
exir_ops.edge.aten.slice_copy.Tensor,
101+
exir_ops.edge.aten.sub.Tensor,
102+
exir_ops.edge.aten.sum.dim_IntList,
103+
exir_ops.edge.aten.tanh.default,
104+
exir_ops.edge.aten.upsample_nearest2d.vec,
105+
exir_ops.edge.aten.view_copy.default,
106+
exir_ops.edge.aten.clone.default,
107+
exir_ops.edge.aten.unsqueeze_copy.default,
108+
exir_ops.edge.aten.squeeze_copy.dims,
109+
operator.getitem,
110+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
111+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
112+
]
113+
114+
if not supported:
115+
supported = self.is_node_supported_custom(node)
116+
117+
# Override partitioning based on pre partition passes
118+
if "arm_override_partition" in node.meta:
119+
supported = supported & node.meta["arm_override_partition"]
120+
node.meta.pop("arm_override_partition")
121+
122+
return supported
123+
124+
def is_node_supported_custom(self, node: fx.Node) -> bool:
125+
tosa_checks = get_registered_tosa_support_checks(self.tosa_spec)
126+
if node.target in tosa_checks.keys():
127+
return tosa_checks[node.target].is_node_supported(node, self.tosa_spec)
128+
return False
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from typing import cast
9+
10+
import torch.fx as fx
11+
12+
from executorch.backends.arm.operator_support.tosa_supported_operators import (
13+
register_tosa_support_check,
14+
SupportedTOSAOperatorCheck,
15+
)
16+
from executorch.backends.arm.tosa_specification import TosaSpecification
17+
from executorch.exir.dialects._ops import ops as exir_ops
18+
19+
20+
@register_tosa_support_check
21+
class VarCorrectionSupported(SupportedTOSAOperatorCheck):
22+
targets = [exir_ops.edge.aten.var.correction]
23+
24+
tosa_specs = [
25+
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
26+
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
27+
]
28+
29+
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:
30+
assert node.target in self.targets
31+
32+
keep_dim = node.kwargs.get("keepdim", False)
33+
return cast(bool, keep_dim)

0 commit comments

Comments
 (0)