Skip to content

Change partition function to take exported program #340

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions backends/qnnpack/partition/qnnpack_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Partitioner,
PartitionResult,
)
from torch.export import ExportedProgram
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -51,9 +52,11 @@ def check_partitions(partitions: Union[dict, list]) -> None:
else:
log.info(f"Found {pl} subgraphs to be partitioned.")

def partition(self, graph_module: torch.fx.GraphModule) -> PartitionResult:
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
raise NotImplementedError("This is not meant to be used directly.")
return PartitionResult(tagged_graph=graph_module, partition_tags={})
return PartitionResult(
tagged_exported_program=exported_program, partition_tags={}
)


class _SingleOpDelegatePartitioner(_BasePartitioner):
Expand All @@ -75,16 +78,18 @@ def __init__(
self.transforms = transforms

# override
def partition(self, graph_module: torch.fx.GraphModule) -> PartitionResult:
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
# TODO delete this since we are not allowed to do this
if self.transforms is not None:
for transform in self.transforms: # pyre-ignore
graph_module.graph = transform(graph_module.graph)
exported_program = exported_program._transform(transform)

matches = [
match
for matches in (
SubgraphMatcher(pattern, ignore_literals=True).match(graph_module.graph)
SubgraphMatcher(pattern, ignore_literals=True).match(
exported_program.graph
)
for pattern in self.patterns
)
for match in matches
Expand Down Expand Up @@ -130,7 +135,9 @@ def partition(self, graph_module: torch.fx.GraphModule) -> PartitionResult:
partition_tags[delegation_tag] = self.delegation_spec
tag_mapping[delegation_tag] = match_set

return PartitionResult(tagged_graph=graph_module, partition_tags=partition_tags)
return PartitionResult(
tagged_exported_program=exported_program, partition_tags=partition_tags
)


class QnnpackPartitioner(_SingleOpDelegatePartitioner):
Expand Down
29 changes: 18 additions & 11 deletions backends/xnnpack/partition/xnnpack_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
PartitionResult,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export import ExportedProgram
from torch.fx.passes.infra.partitioner import Partition
from torch.fx.passes.operator_support import OperatorSupportBase

Expand Down Expand Up @@ -403,16 +404,18 @@ def tag_nodes(self, partitions: List[Partition]) -> Dict[str, DelegationSpec]:
return partition_tags

# override
def partition(self, graph_module: torch.fx.GraphModule) -> PartitionResult:
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
"""
Run the partitioner on the given graph module, then tag each partition
with its delegation tag (and partition id)
"""
partitions = self.generate_partitions(graph_module)
partitions = self.generate_partitions(exported_program.graph_module)
partition_tags: Dict[str, DelegationSpec] = {}
if self.check_partitions(partitions):
partition_tags = self.tag_nodes(partitions)
return PartitionResult(tagged_graph=graph_module, partition_tags=partition_tags)
return PartitionResult(
tagged_exported_program=exported_program, partition_tags=partition_tags
)


# TODO: Merge XnnpackQuantizedPartitioner and XnnpackFloatingPointPartitioner
Expand Down Expand Up @@ -788,20 +791,22 @@ def tag_nodes(self, partitions: List[Partition]) -> Dict[str, DelegationSpec]:

# override
def _partition(
self, graph_module: torch.fx.GraphModule, quant: Optional[bool]
self, exported_program: ExportedProgram, quant: Optional[bool]
) -> PartitionResult:
"""
Run the partitioner on the given graph module, then tag each partition
with its delegation tag (and partition id)
"""
partitions = self.generate_partitions(graph_module, quant)
partitions = self.generate_partitions(exported_program.graph_module, quant)
partition_tags: Dict[str, DelegationSpec] = {}
if self.check_partitions(partitions):
partition_tags = self.tag_nodes(partitions)
return PartitionResult(tagged_graph=graph_module, partition_tags=partition_tags)
return PartitionResult(
tagged_exported_program=exported_program, partition_tags=partition_tags
)

def partition(self, graph_module: torch.fx.GraphModule) -> PartitionResult:
ret: PartitionResult = self._partition(graph_module, self.quant)
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
ret: PartitionResult = self._partition(exported_program, self.quant)
return ret


Expand All @@ -814,7 +819,7 @@ def __init__(
super().__init__(supported_modules, supported_ops)

# override
def partition(self, graph_module: torch.fx.GraphModule) -> PartitionResult:
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
"""
Run the partitioner on the given graph module, then tag each partition with its delegation tag (and partition id)

Expand All @@ -826,10 +831,12 @@ def partition(self, graph_module: torch.fx.GraphModule) -> PartitionResult:
id=next(partition_id),
nodes=set(match),
)
for match in self.get_module_partitions(graph_module)
for match in self.get_module_partitions(exported_program.graph_module)
]
partition_tags: Dict[str, DelegationSpec] = {}

if self.check_partitions(partitions):
partition_tags = self.tag_nodes(partitions)
return PartitionResult(tagged_graph=graph_module, partition_tags=partition_tags)
return PartitionResult(
tagged_exported_program=exported_program, partition_tags=partition_tags
)
4 changes: 2 additions & 2 deletions docs/website/docs/tutorials/backend_delegate.md
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ class Backend_1_2_Partitioner(Partitioner):
self.delegation_spec_2 = DelegationSpec("Backend2", [])

def partition(
self, edge_graph_module: torch.fx.GraphModule
self, exported_program: ExportedProgram
) -> PartitionResult:
partition_tags: Dict[str, DelegationSpec] = {}
# Tag all nodes in the first partition to backend 1
Expand All @@ -388,7 +388,7 @@ class Backend_1_2_Partitioner(Partitioner):
node.meta["delegation_tag"] = delegation_tag
partition_tags[delegation_tag] = self.delegation_spec_2
return PartitionResult(
tagged_graph=edge_graph_module, partition_tags=partition_tags
tagged_exported_program=exported_program, partition_tags=partition_tags
)
```

Expand Down
13 changes: 10 additions & 3 deletions examples/example_quantizer_and_delegate/example_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.graph_module import get_control_flow_submodules
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.export import ExportedProgram
from torch.fx.passes.operator_support import OperatorSupportBase


Expand All @@ -46,7 +47,9 @@ def is_node_supported(self, _submodules, node: torch.fx.Node) -> bool:

self.dequant_quant_support = DequantQuantOperatorSupport()

def partition(self, edge_graph_module: torch.fx.GraphModule) -> PartitionResult:
def _partition_graph_module(
self, edge_graph_module: torch.fx.GraphModule
) -> Dict[str, DelegationSpec]:
partition_tags: Dict[str, DelegationSpec] = {}
partition_nodes = []
for pattern in self.patterns:
Expand Down Expand Up @@ -77,8 +80,12 @@ def partition(self, edge_graph_module: torch.fx.GraphModule) -> PartitionResult:
partition_tags[delegation_tag] = self.delegation_spec

for _, submodule, _ in get_control_flow_submodules(edge_graph_module):
self.partition(submodule)
self._partition_graph_module(submodule)

return partition_tags

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
partition_tag = self._partition_graph_module(exported_program.graph_module)
return PartitionResult(
tagged_graph=edge_graph_module, partition_tags=partition_tags
tagged_exported_program=exported_program, partition_tags=partition_tag
)
15 changes: 7 additions & 8 deletions exir/backend/backend_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
LoweredBackendModule,
)
from executorch.exir.pass_base import ExportPass
from torch._export.exported_program import ExportedProgram
from torch.export import ExportedProgram


@singledispatch
Expand Down Expand Up @@ -281,18 +281,17 @@ def to_backend(
Returns:
ExportedProgram: The input program, with some portions targeted for delegation.
"""
edge_graph_module = edge_program.graph_module
copied_graph_module = copy.deepcopy(edge_graph_module)
copied_edge_program = copy.deepcopy(edge_program)
# Call the partitioner on the given graph module
partitioner_instance: Partitioner = partitioner()
partitioner_result = partitioner_instance(copied_graph_module)
tagged_graph_module = partitioner_result.tagged_graph
partitioner_result = partitioner_instance(copied_edge_program)
tagged_exported_program = partitioner_result.tagged_exported_program

# Check that the partitioner did not modify the original graph
if _ENABLE_VALIDATION:
assert is_identical_graph(
tagged_graph_module,
edge_graph_module,
tagged_exported_program.graph_module,
edge_program.graph_module,
), f"The partitioner {partitioner} should not modify the graph module"
else:
logging.warning("Disabled validating the partitioner.")
Expand All @@ -302,7 +301,7 @@ def to_backend(
), f"Partitioner {partitioner} needs a `partition_tags` field containing a mapping of tags to delegate spec"

tagged_graph_module = _partition_and_lower(
tagged_graph_module, partitioner_result, edge_program
tagged_exported_program.graph_module, partitioner_result, edge_program
)

# TODO(angelayi): Update this signature in a less manual way (maybe through
Expand Down
29 changes: 14 additions & 15 deletions exir/backend/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
from dataclasses import dataclass
from typing import Dict, List, NamedTuple, TypeVar

import torch.fx as fx

from executorch.exir.backend.backend_details import enforcedmethod
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export import ExportedProgram


class DelegationSpec(NamedTuple):
Expand All @@ -22,20 +21,20 @@ class DelegationSpec(NamedTuple):
@dataclass
class PartitionResult:
"""
tagged_graph: the graph with nodes that intend to be delegated containing a "DelegationSpec" metadata
tagged_exported_program: the graph with nodes that intend to be delegated containing a "DelegationSpec" metadata
partition_tags: A dictionary that will be used to keep track of the tags and its corresponding DelegationSpec. The tag is defined by users and used
in the node.meta.
"""

tagged_graph: fx.GraphModule
tagged_exported_program: ExportedProgram
partition_tags: Dict[str, DelegationSpec]


class Partitioner(ABC):
"""
Defines a callable interface for partitioning an exported module (i.e. a program) for
Defines a callable interface for partitioning an exported program for
backend delegation.
A partitioner implementation would receive an exported module, determine what portions of
A partitioner implementation would receive an exported program, determine what portions of
it can be delegated to a certain backend (though a partitioner can target multiple
backends as well), and return the PartitionResult including:
- the same input module with specific nodes in the input graph tagged for delegation
Expand All @@ -51,31 +50,31 @@ class Partitioner(ABC):
the same format.

Args:
edge_graph_module: A module in Edge dialect to be partitioned for backend delegation.
exported_program: An ExportedProgram in Edge dialect to be partitioned for backend delegation.
"""

def __call__(self, edge_graph_module: fx.GraphModule) -> PartitionResult:
return self.partition(edge_graph_module)
def __call__(self, exported_program: ExportedProgram) -> PartitionResult:
return self.partition(exported_program)

@enforcedmethod
@abstractmethod
def partition(self, edge_graph_module: fx.GraphModule) -> PartitionResult:
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
"""
Returns the input exported program with newly created sub-Modules encapsulating
specific portions of the input "tagged" for delegation.

The specific implementation is free to decide how existing computation in the
input Module should be delegated to one or even more than one specific
input exported program should be delegated to one or even more than one specific
backends.

The contract is stringent in that:
* Each node that is intended to be delegated must be tagged
* No change in the original input Module (GraphModule) representation can take
* No change in the original input exported program (ExportedProgram) representation can take
place other than adding sub-Modules for encapsulating existing portions of the
input Module and the associated metadata for tagging.
input exported program and the associated metadata for tagging.

Args:
edge_graph_module: A module in Edge dialect to be partitioned for backend delegation.
exported_program: An ExportedProgram in Edge dialect to be partitioned for backend delegation.

Returns:
PartitionResult: includes the tagged graph and the delegation spec to indicate what backend_id and compile_spec is used for each node and the tag created by the backend developers.
Expand All @@ -84,5 +83,5 @@ def partition(self, edge_graph_module: fx.GraphModule) -> PartitionResult:


# Define Type variables to allow instantiate an instance a subclass of Partitioner
# in to_backend(edge_graph_module: torch.fx.GraphModule, partitioner: Type[TPartitioner])
# in to_backend(edge_exported_program: ExportedProgram, partitioner: Type[TPartitioner])
TPartitioner = TypeVar("TPartitioner", bound=Partitioner)
8 changes: 5 additions & 3 deletions exir/backend/test/demos/rpc/executor_backend_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from executorch.exir.backend.test.backend_with_compiler_demo import (
BackendWithCompilerDemo,
)
from torch.export import ExportedProgram
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase


Expand Down Expand Up @@ -49,10 +50,10 @@ def __init__(self) -> None:
self.op_support = any_chain(AnyOperatorSupport(), AnyDelegateSupport())
self.delegation_spec = DelegationSpec("ExecutorBackend", [])

def partition(self, edge_graph_module: torch.fx.GraphModule) -> PartitionResult:
def partition(self, edge_exported_program: ExportedProgram) -> PartitionResult:
partition_tags = {}
partition_list = generate_pattern_op_partitions(
edge_graph_module, op_support=self.op_support
edge_exported_program.graph_module, op_support=self.op_support
)
for partition in partition_list:
for node in partition.nodes:
Expand All @@ -67,5 +68,6 @@ def partition(self, edge_graph_module: torch.fx.GraphModule) -> PartitionResult:
node.args[0].meta["delegation_tag"] = delegation_tag

return PartitionResult(
tagged_graph=edge_graph_module, partition_tags=partition_tags
tagged_exported_program=edge_exported_program,
partition_tags=partition_tags,
)
16 changes: 9 additions & 7 deletions exir/backend/test/hta_partitioner_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
PartitionResult,
)
from executorch.exir.backend.test.qnn_backend_demo import QnnBackend
from torch.fx import GraphModule
from torch.export import ExportedProgram
from torch.fx.passes.infra.partitioner import Partition


Expand Down Expand Up @@ -191,15 +191,17 @@ def generate_partition_list(self, graph_module) -> List[Partition]:

return flat_proposed_partitions_with_unique_id

def partition(self, graph_module: GraphModule) -> PartitionResult:
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
partition_tags = {}
partition_list = self.generate_partition_list(graph_module)
partition_list = self.generate_partition_list(exported_program.graph_module)
for partition in partition_list:
for node in partition.nodes:
delegation_tag = f"tag{partition.id}"
node.meta["delegation_tag"] = delegation_tag
partition_tags[delegation_tag] = self.delegation_spec
return PartitionResult(tagged_graph=graph_module, partition_tags=partition_tags)
return PartitionResult(
tagged_exported_program=exported_program, partition_tags=partition_tags
)


@final
Expand Down Expand Up @@ -268,16 +270,16 @@ def forward(self, x_raw, h, c):
backend_id = QnnBackend.__name__
self.delegation_spec = DelegationSpec(backend_id, [])

def partition(self, edge_graph_module: torch.fx.GraphModule) -> PartitionResult:
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
partition_tags = {}
partition_list = generate_pattern_op_partitions(
edge_graph_module, patterns=self.patterns
exported_program.graph_module, patterns=self.patterns
)
for partition in partition_list:
for node in partition.nodes:
delegation_tag = f"tag{partition.id}"
node.meta["delegation_tag"] = delegation_tag
partition_tags[delegation_tag] = self.delegation_spec
return PartitionResult(
tagged_graph=edge_graph_module, partition_tags=partition_tags
tagged_exported_program=exported_program, partition_tags=partition_tags
)
Loading