
Commit c3357e1

mcr229 authored and facebook-github-bot committed
use _preserve_ops for to_edge_transform_and_lower (#4273)
Summary:
Pull Request resolved: #4273

## Motivation

`run_decompositions()` has new `_preserve_ops` functionality which allows us to specify which ops we want to refrain from decomposing. This is very helpful for the `to_edge_transform_and_lower` API because it allows us to prevent decompositions that occur beyond the first level. Consider LSTM as an example: when exported using torch.export, we see a torch.ops.aten.LSTM() operator in the graph. When running decompositions, this is decomposed into linear, which is then further decomposed into addmm. Since the linear op is produced by decomposing LSTM and does not exist until after we run_decompositions(), we cannot perform our trick of changing the namespace to prevent its decomposition. However, using `_preserve_ops=(torch.ops.aten.linear.default,)` we are now able to prevent this second-level decomposition.

## API Implementation Change

The implementation makes two passes. In the first pass, we run_decompositions preserving all aten ops specified by our partitioners via `_preserve_ops`. In the second pass, we further filter which aten ops should be preserved by using the check_op_fn given to us by the partitioners, and then use our namespace trick to prevent the decomposition of all aten ops which pass check_op_fn.

## Testing Changes

To strengthen the tests, the functionality of the NonDecompTestPartitioner is changed first: we partition only pre-decomposition aten ops, and each of these ops lives within its own delegate (this gives us a 1:1 mapping between call_delegate nodes and pre-decomposition aten nodes). This allows the tests to ensure that the number of preserved ops is correct by counting the number of delegate calls. The tests then count the number of aten ops which should correctly be preserved, and check after the fact that each of these ops is (1) no longer in the graph after to_edge_transform_and_lower and (2) transformed into a call_delegate node.

Reviewed By: tarun292

Differential Revision: D59786323

fbshipit-source-id: 7ea946e0d5afc8ebddd26913f6e843305116ad3b
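To make the motivating example concrete, here is a minimal sketch of the preserve behavior, assuming a PyTorch build of this era in which `ExportedProgram.run_decompositions` still accepts the private `_preserve_ops` argument (the module itself is illustrative, not from this commit):

```python
import torch
from torch.export import export


class LSTMModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

    def forward(self, x):
        return self.lstm(x)


# The exported (pre-decomposition) graph contains an aten LSTM operator.
ep = export(LSTMModel(), (torch.rand(1, 10, 8),))

# Plain decomposition would lower LSTM -> linear -> addmm. Preserving
# aten.linear stops the second-level decomposition (assumes the private
# _preserve_ops argument available at the time of this commit).
ep = ep.run_decompositions(_preserve_ops=(torch.ops.aten.linear.default,))

linear_nodes = [
    n
    for n in ep.graph.nodes
    if n.op == "call_function" and n.target == torch.ops.aten.linear.default
]
print(f"preserved {len(linear_nodes)} aten.linear node(s)")
```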
1 parent b448254 commit c3357e1

File tree

3 files changed, +148 -36 lines


exir/backend/test/op_partitioner_demo.py

Lines changed: 23 additions & 4 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import itertools
 from typing import Callable, Dict, final, List, Optional, Tuple
 
 import torch
@@ -24,6 +25,7 @@
 from executorch.exir.graph_module import get_control_flow_submodules
 from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
 from torch.export import ExportedProgram
+from torch.fx.passes.infra.partitioner import Partition
 from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
 
 
@@ -145,10 +147,12 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
 @final
 class NonDecompTestPartitioner(Partitioner):
     """
-    Partitions all add/mul nodes regardless of order
+    Non Decomp Test Partitioner, preserves aten ops from decomposition for delegate
+    consumption. Ensures that non_decomposed_edge_ops are all within their own delegate
     """
 
     def __init__(self) -> None:
+        self.supported_non_decomposed_edge_ops = edge_ops_non_decomposed
         self.op_support = any_chain(OpsToNotDecomposeOperatorSupport())
         self.delegation_spec = DelegationSpec(
             BackendWithCompilerDemo.__name__,
@@ -171,14 +175,29 @@ def filter_ops(node: torch.fx.Node) -> bool:
 
         return (ops_not_to_decompose, filter_ops)
 
+    def _generate_single_node_partition(
+        self, gm: torch.fx.GraphModule
+    ) -> List[Partition]:
+        partitions = []
+        partition_id = itertools.count()
+        nodes_seen = set()
+        for node in gm.graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target in self.supported_non_decomposed_edge_ops
+                and node not in nodes_seen
+            ):
+                partitions.append(Partition(nodes=[node], id=next(partition_id)))
+                nodes_seen.add(node)
+
+        return partitions
+
     def _partition_graph_module(
         self,
         graph_module: torch.fx.GraphModule,
     ) -> Dict[str, DelegationSpec]:
         partition_tags: Dict[str, DelegationSpec] = {}
-        partition_list = generate_pattern_op_partitions(
-            graph_module, op_support=self.op_support
-        )
+        partition_list = self._generate_single_node_partition(graph_module)
         for partition in partition_list:
             for node in partition.nodes:
                 delegation_tag = f"tag{partition.id}"
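The single-node partitioning above is plain torch.fx machinery. As a self-contained sketch of the same pattern on a toy traced module (the module and target set here are illustrative, not part of this commit):

```python
import itertools
import operator
from typing import List

import torch
from torch.fx.passes.infra.partitioner import Partition


def single_node_partitions(gm: torch.fx.GraphModule, targets) -> List[Partition]:
    # One Partition per matching call_function node, mirroring
    # _generate_single_node_partition: a 1:1 node-to-delegate mapping.
    counter = itertools.count()
    return [
        Partition(nodes=[node], id=next(counter))
        for node in gm.graph.nodes
        if node.op == "call_function" and node.target in targets
    ]


class AddMul(torch.nn.Module):  # illustrative toy module
    def forward(self, x):
        return torch.add(x, x) * x


gm = torch.fx.symbolic_trace(AddMul())
parts = single_node_partitions(gm, {torch.add, operator.mul})
assert len(parts) == 2  # one partition per matched node
```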

exir/program/_program.py

Lines changed: 62 additions & 27 deletions
@@ -859,6 +859,66 @@ def _sanity_check_graph_for_non_decomp_ops(
         logging.warning(warning_str)
 
 
+def _gen_edge_manager_for_partitioners(
+    partitioner: Dict[str, List[Partitioner]],
+    aten_programs: Dict[str, ExportedProgram],
+    config: EdgeCompileConfig,
+    constant_methods: Optional[Dict[str, Any]],
+) -> "EdgeProgramManager":
+    """
+    Generates an EdgeProgramManager for subsequent lowering to the
+    partitioners specified by partitioner. The EdgeProgramManager is generated from
+    aten_programs.
+
+    Partitioners specify which nodes should not be decomposed from the original aten programs.
+    This is done through two passes of run_decompositions:
+        - The first pass preserves all aten targets specified by partitioners to protect
+          them from nested decompositions
+        - The second pass uses the check_op fn provided by partitioners to perform additional
+          checks on nodes with preserved aten targets. These are then replaced with transformed
+          ops to keep them through the second pass of decompositions
+    """
+    ops_set_to_not_decompose_by_program = {}
+    edge_programs: Dict[str, ExportedProgram] = {}
+    for name, program in aten_programs.items():
+        if partitioner is not None:
+            # preserve all ops listed by all partitioners first
+            all_ops_no_decomp = set()
+            for curr_partitioner in partitioner.get(name, []):
+                curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
+                all_ops_no_decomp |= set(curr_ops_no_decomp)
+
+            program = program.run_decompositions(
+                _default_decomposition_table(), _preserve_ops=tuple(all_ops_no_decomp)
+            )
+            # Among all the preserved aten ops, use the check_op_fn to do an additional
+            # check on which ops need to be preserved and which ops need to be decomposed.
+            # Those which are truly preserved will be replaced with transformed ops
+            ops_set_to_not_decompose_by_program[name] = (
+                _replace_aten_ops_with_transformed_ops(name, program, partitioner)
+            )
+        program = program.run_decompositions(_default_decomposition_table())
+
+        _restore_transformed_ops_to_aten_ops(program)
+
+        edge_programs[name] = program
+
+        edge_programs[name] = _generate_edge_program(
+            name,
+            config,
+            program,
+            list(ops_set_to_not_decompose_by_program.get(name, [])),
+        )
+
+    edge_manager = EdgeProgramManager(
+        edge_programs,
+        constant_methods,
+        config,
+        list(set().union(*ops_set_to_not_decompose_by_program.values())),
+    )
+    return edge_manager
+
+
 def _to_edge_transform_and_lower(
     programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
     transform_passes: Optional[
@@ -909,8 +969,6 @@ def _to_edge_transform_and_lower(
     Returns:
         EdgeProgramManager
     """
-    ops_set_to_not_decompose = set()
-
     assert not isinstance(constant_methods, EdgeCompileConfig)
     config = compile_config or EdgeCompileConfig()
     if not isinstance(programs, dict):
@@ -923,31 +981,8 @@ def _to_edge_transform_and_lower(
     else:
         partitioner = {}
 
-    ops_set_to_not_decompose_by_program = {}
-    edge_programs: Dict[str, ExportedProgram] = {}
-    for name, program in aten_programs.items():
-        if partitioner is not None:
-            ops_set_to_not_decompose_by_program[name] = (
-                _replace_aten_ops_with_transformed_ops(name, program, partitioner)
-            )
-        program = program.run_decompositions(_default_decomposition_table())
-
-        _restore_transformed_ops_to_aten_ops(program)
-
-        edge_programs[name] = program
-
-        edge_programs[name] = _generate_edge_program(
-            name,
-            config,
-            program,
-            list(ops_set_to_not_decompose_by_program.get(name, [])),
-        )
-
-    edge_manager = EdgeProgramManager(
-        edge_programs,
-        constant_methods,
-        config,
-        list(set().union(*ops_set_to_not_decompose_by_program.values())),
+    edge_manager = _gen_edge_manager_for_partitioners(
+        partitioner, aten_programs, config, constant_methods
     )
 
     if transform_passes is not None:
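For intuition, here is a hedged sketch of the first of the two passes that `_gen_edge_manager_for_partitioners` performs: union every partitioner's ops_to_not_decompose, then decompose while preserving that union. `PartitionerLike` is a stand-in for the real Partitioner interface, and `_preserve_ops` is the same private torch argument noted above:

```python
from typing import Callable, List, Optional, Tuple

import torch
from torch.export import ExportedProgram


class PartitionerLike:  # stand-in, not the real ExecuTorch Partitioner
    def ops_to_not_decompose(
        self, ep: ExportedProgram
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable]]:
        return [torch.ops.aten.linear.default], None


def first_preserve_pass(
    ep: ExportedProgram, partitioners: List[PartitionerLike]
) -> ExportedProgram:
    # Union of every op any partitioner asked to keep, so ops born from
    # nested decompositions (e.g. aten.linear out of an LSTM) survive.
    all_ops_no_decomp = set()
    for p in partitioners:
        ops, _filter_fn = p.ops_to_not_decompose(ep)
        all_ops_no_decomp |= set(ops)
    # The second pass then filters with each partitioner's check_op_fn and
    # applies the namespace trick before decomposing everything else.
    return ep.run_decompositions(_preserve_ops=tuple(all_ops_no_decomp))
```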

exir/program/test/test_program.py

Lines changed: 63 additions & 5 deletions
@@ -6,7 +6,7 @@
 
 # pyre-strict
 
-import operator
+import copy
 import unittest
 from typing import Any, Dict
 
@@ -27,6 +27,7 @@
     ExecutorchProgramManager,
     to_edge,
 )
+from executorch.exir.tracer import _default_decomposition_table
 from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
 
 from executorch.extension.pybindings.portable_lib import (
@@ -102,6 +103,19 @@ def _get_random_inputs(cls):
         return (x,)
 
 
+class TestLSTM(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.lstm = torch.nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
+
+    def forward(self, x):
+        return self.lstm(x)
+
+    @classmethod
+    def _get_random_inputs(cls):
+        return (torch.rand(1, 10, 8),)
+
+
 class WrapperModule(torch.nn.Module):
     def __init__(self, fn):
         super().__init__()
@@ -550,23 +564,65 @@ def _use_foo_add(a: torch.Tensor, b: torch.Tensor):
         except SpecViolationError:
             self.fail("Should not error out on custom op")
 
+    def get_num_nondecomposed_ops(self, ep, partitioner):
+        # Count the number of aten ops that the partitioner can delegate.
+        # We do this by running run_decompositions() with the preserved ops given
+        # to us by the partitioner, then counting the preserved aten ops
+        # which pass the filter_ops fn given by the partitioner.
+        reference_ep = copy.deepcopy(ep)
+        aten_ops_not_decomposed, filter_ops = partitioner.ops_to_not_decompose(ep)
+        reference_decomp_ep = reference_ep.run_decompositions(
+            decomp_table=_default_decomposition_table(),
+            _preserve_ops=tuple(aten_ops_not_decomposed),
+        )
+        num_non_decomposed_aten_ops = 0
+        for node in reference_decomp_ep.graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target in aten_ops_not_decomposed
+                and (filter_ops(node) if filter_ops else True)
+            ):
+                num_non_decomposed_aten_ops += 1
+        return num_non_decomposed_aten_ops
+
     def _test_model_with_non_decomp_partitioner(self, model: torch.nn.Module):
         # This is the pre-dispatch export that we will be switching to primarily
         # in the near future. The input to _to_edge_transform_and_lower needs to
         # be a graph generated by this pre-dispatch export.
         ep = _export(model, model._get_random_inputs(), pre_dispatch=True)
+        non_decomp_partitioner = NonDecompTestPartitioner()
+
+        num_non_decomposed_aten_ops = self.get_num_nondecomposed_ops(
+            ep, non_decomp_partitioner
+        )
+
+        # run to_edge_transform_and_lower
         edge = _to_edge_transform_and_lower(
             ep,
             compile_config=EdgeCompileConfig(),
             partitioner=[NonDecompTestPartitioner()],
         )
+        # Check that non_decomposed_edge_ops are all consumed by the delegate
+        non_decomposed_edge_ops = (
+            non_decomp_partitioner.supported_non_decomposed_edge_ops
+        )
+        for node in edge.exported_program().graph.nodes:
+            if node.op == "call_function":
+                self.assertTrue(node.target not in non_decomposed_edge_ops)
+
+        # Check that the number of call_delegate nodes equals the number of
+        # non-decomposed aten ops we found above
+        num_call_delegates = 0
         for node in edge.exported_program().graph_module.graph.nodes:
             # There should only be a single call_function node in the graph
             # and that should be a call_delegate node.
-            if node.op == "call_function" and node.target != operator.getitem:
-                self.assertEqual(
-                    node.target, torch.ops.higher_order.executorch_call_delegate
-                )
+            if (
+                node.op == "call_function"
+                and node.target == torch.ops.higher_order.executorch_call_delegate
+            ):
+                num_call_delegates += 1
+
+        self.assertEqual(num_call_delegates, num_non_decomposed_aten_ops)
 
     def test_to_edge_transform_and_lower(self):
         self._test_model_with_non_decomp_partitioner(TestLinear())
@@ -577,6 +633,8 @@ def test_to_edge_transform_and_lower(self):
 
         self._test_model_with_non_decomp_partitioner(TestUpsample())
 
+        self._test_model_with_non_decomp_partitioner(TestLSTM())
+
     def test_to_edge_transform_and_lower_with_exception(self):
         class TestLinear(torch.nn.Module):
             def __init__(self):
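End to end, the path these tests exercise looks roughly like the sketch below; the private import locations (`_export` with pre_dispatch, `_to_edge_transform_and_lower`) are assumptions based on this era of the codebase and may differ by version:

```python
import torch
from executorch.exir import EdgeCompileConfig
from executorch.exir.backend.test.op_partitioner_demo import NonDecompTestPartitioner
from executorch.exir.program._program import _to_edge_transform_and_lower
from torch.export._trace import _export


class LSTMModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

    def forward(self, x):
        return self.lstm(x)


# Pre-dispatch export, as _to_edge_transform_and_lower expects.
ep = _export(LSTMModule(), (torch.rand(1, 10, 8),), pre_dispatch=True)

edge = _to_edge_transform_and_lower(
    ep,
    compile_config=EdgeCompileConfig(),
    partitioner=[NonDecompTestPartitioner()],
)
# Each preserved aten op should now appear as an executorch_call_delegate call.
print(edge.exported_program().graph_module.code)
```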
