Commit 2b91eba

tarun292 authored and facebook-github-bot committed
to_edge_transform_and_lower (#3483)
Summary: Pull Request resolved: #3483

This diff introduces the to_edge_transform_and_lower API. The changes introduced are:
- Support in the Partitioner class for registering ops that it does not want to be decomposed.
- Changes to _program.py adding the implementation of to_edge_transform_and_lower().
- A basic test case verifying that Linear, SDPA, and Linear + SDPA are not decomposed when requested, and that the corresponding backend consumes them.

Reviewed By: kimishpatel, mcr229

Differential Revision: D56401086

fbshipit-source-id: 04262a58fc70e8191df33b4342295e56a5baf354
1 parent 636c5c3 commit 2b91eba
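For orientation, here is a minimal usage sketch of the new entry point. The exact public signature lives in _program.py, whose diff is not shown on this page, so the `partitioner` keyword and the export-then-lower flow below are assumptions based on the summary and the test partitioner in this diff:

import torch

from executorch.exir.program import _to_edge_transform_and_lower
from executorch.exir.backend.test.op_partitioner_demo import NonDecompTestPartitioner

class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)  # should stay a single aten.linear op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

# Export first; _to_edge_transform_and_lower then converts to edge dialect
# without decomposing the ops the partitioner registered, and lowers the
# tagged partitions to the demo backend in one step.
exported = torch.export.export(TinyModel(), (torch.randn(2, 8),))
edge_manager = _to_edge_transform_and_lower(
    exported,
    partitioner=[NonDecompTestPartitioner()],  # assumed keyword argument
)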

File tree

7 files changed: +525 −15 lines changed

exir/backend/partitioner.py

Lines changed: 21 additions & 1 deletion
@@ -7,7 +7,9 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from types import MappingProxyType
-from typing import Dict, List, Mapping, NamedTuple, Union
+from typing import Callable, Dict, List, Mapping, NamedTuple, Optional, Tuple, Union
+
+import torch
 
 from executorch.exir.backend.backend_details import enforcedmethod
 from executorch.exir.backend.compile_spec_schema import CompileSpec
@@ -91,3 +93,21 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
             PartitionResult: includes the tagged graph and the delegation spec to indicate what backend_id and compile_spec is used for each node and the tag created by the backend developers.
         """
         pass
+
+    def ops_to_not_decompose(
+        self,
+    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
+        """
+        Returns a list of operator names that should not be decomposed. When these ops are
+        registered and `to_backend` is invoked through to_edge_transform_and_lower, it is
+        guaranteed that the program the backend receives will not have any of these ops
+        decomposed.
+
+        Returns:
+            List[torch._ops.OpOverload]: a list of operator names that should not be decomposed.
+            Optional[Callable[[torch.fx.Node], bool]]: an optional callable, acting as a filter,
+                that users can provide; it is called for each node in the graph and lets users
+                select nodes that should continue to be decomposed even though the op they
+                correspond to is in the list returned by ops_to_not_decompose.
+        """
+        return ([], None)
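The consuming side of this hook lives in _program.py, which this page does not show. As a rough sketch of how the returned pair could be used (the helper name here is hypothetical, not from the diff):

from typing import Callable, List, Optional

import torch

def preserved_nodes(
    graph_module: torch.fx.GraphModule,
    ops: List[torch._ops.OpOverload],
    filter_fn: Optional[Callable[[torch.fx.Node], bool]],
) -> List[torch.fx.Node]:
    """Collect the call_function nodes that should escape decomposition."""
    kept = []
    for node in graph_module.graph.nodes:
        if node.op != "call_function" or node.target not in ops:
            continue
        # Per the demo partitioner below, the filter returning True means
        # "keep this node intact"; False sends it through decomposition.
        if filter_fn is None or filter_fn(node):
            kept.append(node)
    return kept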

exir/backend/test/backend_with_compiler_demo.py

Lines changed: 10 additions & 6 deletions
@@ -83,15 +83,19 @@ def preprocess(
         processed_bytes = ""
         number_of_instruction = 0
         debug_handle_map = {}
+        match_ops = [
+            exir_ops.edge.aten.sin.default,
+            exir_ops.edge.aten.mm.default,
+            exir_ops.edge.aten.add.Tensor,
+            torch.ops.aten.sin.default,
+            exir_ops.edge.aten.linear.default,
+            exir_ops.edge.aten.scaled_dot_product_attention.default,
+        ]
+
         for node in edge_program.graph.nodes:
             if node.op == "call_function":
                 # TODO(gasoonjia): remove the support of torch.ops.aten.sin.default after migrate serde to edge dialect.
-                if (
-                    node.target == exir_ops.edge.aten.sin.default
-                    or node.target == exir_ops.edge.aten.mm.default
-                    or node.target == exir_ops.edge.aten.add.Tensor
-                    or node.target == torch.ops.aten.sin.default
-                ):
+                if node.target in match_ops:
                     simple_op = DemoOp(
                         node.target.__name__,
                         int(torch.prod(torch.tensor(node.meta["val"].shape), 0).item()),

exir/backend/test/op_partitioner_demo.py

Lines changed: 73 additions & 1 deletion
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Dict, final
+from typing import Callable, Dict, final, List, Optional, Tuple
 
 import torch
 from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
@@ -71,6 +71,7 @@ def _partition_graph_module(
         for _, submodule, _ in get_control_flow_submodules(graph_module):
             ret_partition_tags = self._partition_graph_module(submodule)
             partition_tags.update(ret_partition_tags)
+
         return partition_tags
 
     def partition(self, exported_program: ExportedProgram) -> PartitionResult:
@@ -121,3 +122,74 @@ def partition(self, edge_exported_program: ExportedProgram) -> PartitionResult:
         return PartitionResult(
             tagged_exported_program=edge_exported_program, partition_tags=partition_tags
         )
+
+
+ops_not_to_decompose = [
+    torch.ops.aten.linear.default,
+    torch.ops.aten.scaled_dot_product_attention.default,
+]
+
+edge_ops_non_decomposed = [
+    exir_ops.edge.aten.linear.default,
+    exir_ops.edge.aten.scaled_dot_product_attention.default,
+]
+
+
+class OpsToNotDecomposeOperatorSupport(OperatorSupportBase):
+    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
+        return node.op == "call_function" and node.target in edge_ops_non_decomposed
+
+
+@final
+class NonDecompTestPartitioner(Partitioner):
+    """
+    Partitions all non-decomposed linear and scaled_dot_product_attention nodes.
+    """
+
+    def __init__(self) -> None:
+        self.op_support = any_chain(OpsToNotDecomposeOperatorSupport())
+        self.delegation_spec = DelegationSpec(
+            BackendWithCompilerDemo.__name__,
+            [CompileSpec("max_value", bytes([4]))],
+        )
+
+    def ops_to_not_decompose(
+        self,
+    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
+        def filter_ops(node: torch.fx.Node) -> bool:
+            if node.op == "call_function" and node.target in ops_not_to_decompose:
+                if len(node.args) == 3:
+                    # A third argument means this linear has a bias, which is the
+                    # only kind of linear this demo partitioner supports.
+                    return True
+                else:
+                    return False
+
+            return True
+
+        return (ops_not_to_decompose, filter_ops)
+
+    def _partition_graph_module(
+        self,
+        graph_module: torch.fx.GraphModule,
+    ) -> Dict[str, DelegationSpec]:
+        partition_tags: Dict[str, DelegationSpec] = {}
+        partition_list = generate_pattern_op_partitions(
+            graph_module, op_support=self.op_support
+        )
+        for partition in partition_list:
+            for node in partition.nodes:
+                delegation_tag = f"tag{partition.id}"
+                node.meta["delegation_tag"] = delegation_tag
+                partition_tags[delegation_tag] = self.delegation_spec
+
+        for _, submodule, _ in get_control_flow_submodules(graph_module):
+            ret_partition_tags = self._partition_graph_module(submodule)
+            partition_tags.update(ret_partition_tags)
+        return partition_tags
+
+    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
+        partition_tags = self._partition_graph_module(exported_program.graph_module)
+        return PartitionResult(
+            tagged_exported_program=exported_program, partition_tags=partition_tags
+        )
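To see what the filter buys in practice, a small hedged sketch (not part of the diff; it assumes torch.export keeps aten.linear.default intact before decompositions run, and that the bias shows up as the third positional argument):

import torch

from executorch.exir.backend.test.op_partitioner_demo import NonDecompTestPartitioner

ops, filter_fn = NonDecompTestPartitioner().ops_to_not_decompose()

class WithBias(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4, bias=True)  # args: input, weight, bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

ep = torch.export.export(WithBias(), (torch.randn(2, 4),))
for node in ep.graph_module.graph.nodes:
    if node.op == "call_function" and node.target in ops:
        # Expected True: a linear with a bias has three args, so the filter
        # preserves it; a bias-free linear (two args) would return False and
        # be decomposed as usual.
        print(node.target, filter_fn(node))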

exir/program/TARGETS

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@ python_library(
     deps = [
         "//caffe2:torch",
         "//executorch/exir:error",
+        "//executorch/exir:graph_module",
         "//executorch/exir:pass_manager",
         "//executorch/exir:print_program",
         "//executorch/exir:schema",
@@ -36,6 +37,7 @@ python_library(
         "//executorch/exir/passes:normalize_view_copy_base_pass",
         "//executorch/exir/passes:remove_graph_asserts_pass",
         "//executorch/exir/passes:remove_mixed_type_operators",
+        "//executorch/exir/passes:replace_aten_with_edge_pass",
         "//executorch/exir/passes:replace_view_copy_with_view_pass",
         "//executorch/exir/passes:spec_prop_pass",
         "//executorch/exir/verification:verifier",

exir/program/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@
 from executorch.exir.program._fake_program import get_fake_program
 from executorch.exir.program._program import (
     _to_edge,
+    _to_edge_transform_and_lower,
     edge_to_executorch_passes,
     EdgeProgramManager,
     ExecutorchProgram,
@@ -22,6 +23,7 @@
     "ExecutorchProgram",
     "_to_edge",
     "to_edge",
+    "_to_edge_transform_and_lower",
     "edge_to_executorch_passes",
     "EdgeProgramManager",
     "ExecutorchProgramManager",
