Skip to content

Commit ce21efe

Browse files
cccclaifacebook-github-bot
authored andcommitted
Change partition function to take exported program (#340)
Summary: Pull Request resolved: #340 right now to_backend takes `ExportedProgram` and for partitioner sometimes it'll need to access parameters/buffers to decide how to partition. this change is needed for partition to take graph with lifted params Reviewed By: mergennachin Differential Revision: D49171496 fbshipit-source-id: 25758a47989d8a6507396cf6ec6d5bda58fe7de3
1 parent 8209cb3 commit ce21efe

File tree

14 files changed

+136
-86
lines changed

14 files changed

+136
-86
lines changed

backends/qnnpack/partition/qnnpack_partitioner.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
Partitioner,
2525
PartitionResult,
2626
)
27+
from torch.export import ExportedProgram
2728
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
2829

2930
logging.basicConfig(level=logging.INFO)
@@ -51,9 +52,11 @@ def check_partitions(partitions: Union[dict, list]) -> None:
5152
else:
5253
log.info(f"Found {pl} subgraphs to be partitioned.")
5354

54-
def partition(self, graph_module: torch.fx.GraphModule) -> PartitionResult:
55+
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
5556
raise NotImplementedError("This is not meant to be used directly.")
56-
return PartitionResult(tagged_graph=graph_module, partition_tags={})
57+
return PartitionResult(
58+
tagged_exported_program=exported_program, partition_tags={}
59+
)
5760

5861

5962
class _SingleOpDelegatePartitioner(_BasePartitioner):
@@ -75,16 +78,18 @@ def __init__(
7578
self.transforms = transforms
7679

7780
# override
78-
def partition(self, graph_module: torch.fx.GraphModule) -> PartitionResult:
81+
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
7982
# TODO delete this since we are not allowed to do this
8083
if self.transforms is not None:
8184
for transform in self.transforms: # pyre-ignore
82-
graph_module.graph = transform(graph_module.graph)
85+
exported_program = exported_program._transform(transform)
8386

8487
matches = [
8588
match
8689
for matches in (
87-
SubgraphMatcher(pattern, ignore_literals=True).match(graph_module.graph)
90+
SubgraphMatcher(pattern, ignore_literals=True).match(
91+
exported_program.graph
92+
)
8893
for pattern in self.patterns
8994
)
9095
for match in matches
@@ -130,7 +135,9 @@ def partition(self, graph_module: torch.fx.GraphModule) -> PartitionResult:
130135
partition_tags[delegation_tag] = self.delegation_spec
131136
tag_mapping[delegation_tag] = match_set
132137

133-
return PartitionResult(tagged_graph=graph_module, partition_tags=partition_tags)
138+
return PartitionResult(
139+
tagged_exported_program=exported_program, partition_tags=partition_tags
140+
)
134141

135142

136143
class QnnpackPartitioner(_SingleOpDelegatePartitioner):

backends/xnnpack/partition/xnnpack_partitioner.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
PartitionResult,
3333
)
3434
from executorch.exir.dialects._ops import ops as exir_ops
35+
from torch.export import ExportedProgram
3536
from torch.fx.passes.infra.partitioner import Partition
3637
from torch.fx.passes.operator_support import OperatorSupportBase
3738

@@ -403,16 +404,18 @@ def tag_nodes(self, partitions: List[Partition]) -> Dict[str, DelegationSpec]:
403404
return partition_tags
404405

405406
# override
406-
def partition(self, graph_module: torch.fx.GraphModule) -> PartitionResult:
407+
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
407408
"""
408409
Run the partitioner on the given graph module, then tag each partition
409410
with its delegation tag (and partition id)
410411
"""
411-
partitions = self.generate_partitions(graph_module)
412+
partitions = self.generate_partitions(exported_program.graph_module)
412413
partition_tags: Dict[str, DelegationSpec] = {}
413414
if self.check_partitions(partitions):
414415
partition_tags = self.tag_nodes(partitions)
415-
return PartitionResult(tagged_graph=graph_module, partition_tags=partition_tags)
416+
return PartitionResult(
417+
tagged_exported_program=exported_program, partition_tags=partition_tags
418+
)
416419

417420

418421
# TODO: Merge XnnpackQuantizedPartitioner and XnnpackFloatingPointPartitioner
@@ -788,20 +791,22 @@ def tag_nodes(self, partitions: List[Partition]) -> Dict[str, DelegationSpec]:
788791

789792
# override
790793
def _partition(
791-
self, graph_module: torch.fx.GraphModule, quant: Optional[bool]
794+
self, exported_program: ExportedProgram, quant: Optional[bool]
792795
) -> PartitionResult:
793796
"""
794797
Run the partitioner on the given graph module, then tag each partition
795798
with its delegation tag (and partition id)
796799
"""
797-
partitions = self.generate_partitions(graph_module, quant)
800+
partitions = self.generate_partitions(exported_program.graph_module, quant)
798801
partition_tags: Dict[str, DelegationSpec] = {}
799802
if self.check_partitions(partitions):
800803
partition_tags = self.tag_nodes(partitions)
801-
return PartitionResult(tagged_graph=graph_module, partition_tags=partition_tags)
804+
return PartitionResult(
805+
tagged_exported_program=exported_program, partition_tags=partition_tags
806+
)
802807

803-
def partition(self, graph_module: torch.fx.GraphModule) -> PartitionResult:
804-
ret: PartitionResult = self._partition(graph_module, self.quant)
808+
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
809+
ret: PartitionResult = self._partition(exported_program, self.quant)
805810
return ret
806811

807812

@@ -814,7 +819,7 @@ def __init__(
814819
super().__init__(supported_modules, supported_ops)
815820

816821
# override
817-
def partition(self, graph_module: torch.fx.GraphModule) -> PartitionResult:
822+
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
818823
"""
819824
Run the partitioner on the given graph module, then tag each partition with its delegegation tag (and partition id)
820825
@@ -826,10 +831,12 @@ def partition(self, graph_module: torch.fx.GraphModule) -> PartitionResult:
826831
id=next(partition_id),
827832
nodes=set(match),
828833
)
829-
for match in self.get_module_partitions(graph_module)
834+
for match in self.get_module_partitions(exported_program.graph_module)
830835
]
831836
partition_tags: Dict[str, DelegationSpec] = {}
832837

833838
if self.check_partitions(partitions):
834839
partition_tags = self.tag_nodes(partitions)
835-
return PartitionResult(tagged_graph=graph_module, partition_tags=partition_tags)
840+
return PartitionResult(
841+
tagged_exported_program=exported_program, partition_tags=partition_tags
842+
)

docs/website/docs/tutorials/backend_delegate.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ class Backend_1_2_Partitioner(Partitioner):
373373
self.delegation_spec_2 = DelegationSpec("Backend2", [])
374374

375375
def partition(
376-
self, edge_graph_module: torch.fx.GraphModule
376+
self, exported_program: ExportedProgram
377377
) -> PartitionResult:
378378
partition_tags: Dict[str, DelegationSpec] = {}
379379
# Tag all nodes in the first partiton to backend 1
@@ -388,7 +388,7 @@ class Backend_1_2_Partitioner(Partitioner):
388388
node.meta["delegation_tag"] = delegation_tag
389389
partition_tags[delegation_tag] = self.delegation_spec_2
390390
return PartitionResult(
391-
tagged_graph=edge_graph_module, partition_tags=partition_tags
391+
tagged_exported_program=exported_program, partition_tags=partition_tags
392392
)
393393
```
394394

examples/example_quantizer_and_delegate/example_partitioner.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from executorch.exir.dialects._ops import ops as exir_ops
2525
from executorch.exir.graph_module import get_control_flow_submodules
2626
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
27+
from torch.export import ExportedProgram
2728
from torch.fx.passes.operator_support import OperatorSupportBase
2829

2930

@@ -46,7 +47,9 @@ def is_node_supported(self, _submodules, node: torch.fx.Node) -> bool:
4647

4748
self.dequant_quant_support = DequantQuantOperatorSupport()
4849

49-
def partition(self, edge_graph_module: torch.fx.GraphModule) -> PartitionResult:
50+
def _partition_graph_module(
51+
self, edge_graph_module: torch.fx.GraphModule
52+
) -> Dict[str, DelegationSpec]:
5053
partition_tags: Dict[str, DelegationSpec] = {}
5154
partition_nodes = []
5255
for pattern in self.patterns:
@@ -77,8 +80,12 @@ def partition(self, edge_graph_module: torch.fx.GraphModule) -> PartitionResult:
7780
partition_tags[delegation_tag] = self.delegation_spec
7881

7982
for _, submodule, _ in get_control_flow_submodules(edge_graph_module):
80-
self.partition(submodule)
83+
self._partition_graph_module(submodule)
8184

85+
return partition_tags
86+
87+
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
88+
partition_tag = self._partition_graph_module(exported_program.graph_module)
8289
return PartitionResult(
83-
tagged_graph=edge_graph_module, partition_tags=partition_tags
90+
tagged_exported_program=exported_program, partition_tags=partition_tag
8491
)

exir/backend/backend_api.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
LoweredBackendModule,
3434
)
3535
from executorch.exir.pass_base import ExportPass
36-
from torch._export.exported_program import ExportedProgram
36+
from torch.export import ExportedProgram
3737

3838

3939
@singledispatch
@@ -281,18 +281,17 @@ def to_backend(
281281
Returns:
282282
ExportedProgram: The input program, with some portions targeted for delegation.
283283
"""
284-
edge_graph_module = edge_program.graph_module
285-
copied_graph_module = copy.deepcopy(edge_graph_module)
284+
copied_edge_program = copy.deepcopy(edge_program)
286285
# Call the partitioner on the given graph module
287286
partitioner_instance: Partitioner = partitioner()
288-
partitioner_result = partitioner_instance(copied_graph_module)
289-
tagged_graph_module = partitioner_result.tagged_graph
287+
partitioner_result = partitioner_instance(copied_edge_program)
288+
tagged_exported_program = partitioner_result.tagged_exported_program
290289

291290
# Check that the partitioner did not modify the original graph
292291
if _ENABLE_VALIDATION:
293292
assert is_identical_graph(
294-
tagged_graph_module,
295-
edge_graph_module,
293+
tagged_exported_program.graph_module,
294+
edge_program.graph_module,
296295
), f"The partitioner {partitioner} should not modify the graph module"
297296
else:
298297
logging.warning("Disabled validating the partitioner.")
@@ -302,7 +301,7 @@ def to_backend(
302301
), f"Partitioner {partitioner} needs a `partition_tags` field containing a mapping of tags to delegate spec"
303302

304303
tagged_graph_module = _partition_and_lower(
305-
tagged_graph_module, partitioner_result, edge_program
304+
tagged_exported_program.graph_module, partitioner_result, edge_program
306305
)
307306

308307
# TODO(angelayi): Update this signature in a less manual way (maybe through

exir/backend/partitioner.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@
88
from dataclasses import dataclass
99
from typing import Dict, List, NamedTuple, TypeVar
1010

11-
import torch.fx as fx
12-
1311
from executorch.exir.backend.backend_details import enforcedmethod
1412
from executorch.exir.backend.compile_spec_schema import CompileSpec
13+
from torch.export import ExportedProgram
1514

1615

1716
class DelegationSpec(NamedTuple):
@@ -22,20 +21,20 @@ class DelegationSpec(NamedTuple):
2221
@dataclass
2322
class PartitionResult:
2423
"""
25-
tagged_graph: the graph with nodes that intend to be delegated containing a "DelegationSpec" metadata
24+
tagged_exported_program: the graph with nodes that intend to be delegated containing a "DelegationSpec" metadata
2625
partition_tags: A dictionary that will be used to keep track of the tags and it's corresponding DelegationSpec. The tag is defined by users and used
2726
in the node.meta.
2827
"""
2928

30-
tagged_graph: fx.GraphModule
29+
tagged_exported_program: ExportedProgram
3130
partition_tags: Dict[str, DelegationSpec]
3231

3332

3433
class Partitioner(ABC):
3534
"""
36-
Defines a callable interface for partitioning an exported module (i.e. a program) for
35+
Defines a callable interface for partitioning an exported program for
3736
backend delegation.
38-
A partitioner implementation would receive an exported module, determine what portions of
37+
A partitioner implementation would receive an exported program, determine what portions of
3938
the it can be delegated to certain backend (though a partitioner can target multiple
4039
backends as well), and return the PartitionResult including:
4140
- the same input module with specific nodes in the input graph tagged for delegation
@@ -51,31 +50,31 @@ class Partitioner(ABC):
5150
the same format.
5251
5352
Args:
54-
edge_graph_module: A module in Edge dialect to be partitioned for backend delegation.
53+
exported_program: An ExportedProgram in Edge dialect to be partitioned for backend delegation.
5554
"""
5655

57-
def __call__(self, edge_graph_module: fx.GraphModule) -> PartitionResult:
58-
return self.partition(edge_graph_module)
56+
def __call__(self, exported_program: ExportedProgram) -> PartitionResult:
57+
return self.partition(exported_program)
5958

6059
@enforcedmethod
6160
@abstractmethod
62-
def partition(self, edge_graph_module: fx.GraphModule) -> PartitionResult:
61+
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
6362
"""
6463
Returns the input exported program with newly created sub-Modules encapsulating
6564
specific portions of the input "tagged" for delegation.
6665
6766
The specific implementation is free to decide how existing computation in the
68-
input Module should be delegated to one or even more than one specific
67+
input exported program should be delegated to one or even more than one specific
6968
backends.
7069
7170
The contract is stringent in that:
7271
* Each node that is intended to be delegated must be tagged
73-
* No change in the original input Module (GraphModule) representation can take
72+
* No change in the original input exported program (ExportedProgram) representation can take
7473
place other than adding sub-Modules for encapsulating existing portions of the
75-
input Module and the associated metadata for tagging.
74+
input exported program and the associated metadata for tagging.
7675
7776
Args:
78-
edge_graph_module: A module in Edge dialect to be partitioned for backend delegation.
77+
exported_program: An ExportedProgram in Edge dialect to be partitioned for backend delegation.
7978
8079
Returns:
8180
PartitionResult: includes the tagged graph and the delegation spec to indicate what backend_id and compile_spec is used for each node and the tag created by the backend developers.
@@ -84,5 +83,5 @@ def partition(self, edge_graph_module: fx.GraphModule) -> PartitionResult:
8483

8584

8685
# Define Type variables to allow instantiate an instance a subclass of Partitioner
87-
# in to_backend(edge_graph_module: torch.fx.GraphModule, partitioner: Type[TPartitioner])
86+
# in to_backend(edge_exported_program: ExportedProgram, partitioner: Type[TPartitioner])
8887
TPartitioner = TypeVar("TPartitioner", bound=Partitioner)

exir/backend/test/demos/rpc/executor_backend_partitioner.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from executorch.exir.backend.test.backend_with_compiler_demo import (
2020
BackendWithCompilerDemo,
2121
)
22+
from torch.export import ExportedProgram
2223
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
2324

2425

@@ -49,10 +50,10 @@ def __init__(self) -> None:
4950
self.op_support = any_chain(AnyOperatorSupport(), AnyDelegateSupport())
5051
self.delegation_spec = DelegationSpec("ExecutorBackend", [])
5152

52-
def partition(self, edge_graph_module: torch.fx.GraphModule) -> PartitionResult:
53+
def partition(self, edge_exported_program: ExportedProgram) -> PartitionResult:
5354
partition_tags = {}
5455
partition_list = generate_pattern_op_partitions(
55-
edge_graph_module, op_support=self.op_support
56+
edge_exported_program.graph_module, op_support=self.op_support
5657
)
5758
for partition in partition_list:
5859
for node in partition.nodes:
@@ -67,5 +68,6 @@ def partition(self, edge_graph_module: torch.fx.GraphModule) -> PartitionResult:
6768
node.args[0].meta["delegation_tag"] = delegation_tag
6869

6970
return PartitionResult(
70-
tagged_graph=edge_graph_module, partition_tags=partition_tags
71+
tagged_exported_program=edge_exported_program,
72+
partition_tags=partition_tags,
7173
)

exir/backend/test/hta_partitioner_demo.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
PartitionResult,
2020
)
2121
from executorch.exir.backend.test.qnn_backend_demo import QnnBackend
22-
from torch.fx import GraphModule
22+
from torch.export import ExportedProgram
2323
from torch.fx.passes.infra.partitioner import Partition
2424

2525

@@ -191,15 +191,17 @@ def generate_partition_list(self, graph_module) -> List[Partition]:
191191

192192
return flat_proposed_partitions_with_unique_id
193193

194-
def partition(self, graph_module: GraphModule) -> PartitionResult:
194+
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
195195
partition_tags = {}
196-
partition_list = self.generate_partition_list(graph_module)
196+
partition_list = self.generate_partition_list(exported_program.graph_module)
197197
for partition in partition_list:
198198
for node in partition.nodes:
199199
delegation_tag = f"tag{partition.id}"
200200
node.meta["delegation_tag"] = delegation_tag
201201
partition_tags[delegation_tag] = self.delegation_spec
202-
return PartitionResult(tagged_graph=graph_module, partition_tags=partition_tags)
202+
return PartitionResult(
203+
tagged_exported_program=exported_program, partition_tags=partition_tags
204+
)
203205

204206

205207
@final
@@ -268,16 +270,16 @@ def forward(self, x_raw, h, c):
268270
backend_id = QnnBackend.__name__
269271
self.delegation_spec = DelegationSpec(backend_id, [])
270272

271-
def partition(self, edge_graph_module: torch.fx.GraphModule) -> PartitionResult:
273+
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
272274
partition_tags = {}
273275
partition_list = generate_pattern_op_partitions(
274-
edge_graph_module, patterns=self.patterns
276+
exported_program.graph_module, patterns=self.patterns
275277
)
276278
for partition in partition_list:
277279
for node in partition.nodes:
278280
delegation_tag = f"tag{partition.id}"
279281
node.meta["delegation_tag"] = delegation_tag
280282
partition_tags[delegation_tag] = self.delegation_spec
281283
return PartitionResult(
282-
tagged_graph=edge_graph_module, partition_tags=partition_tags
284+
tagged_exported_program=exported_program, partition_tags=partition_tags
283285
)

0 commit comments

Comments
 (0)