
Add AoT apis #345

Closed · wants to merge 1 commit
6 changes: 6 additions & 0 deletions exir/__init__.py

@@ -18,11 +18,14 @@
 from executorch.exir.program import (
     _to_edge,
     edge_to_executorch_passes,
+    EdgeProgramManager,
     ExecutorchProgram,
+    ExecutorchProgramManager,
     ExirExportedProgram,
     multi_method_program_to_executorch,
     MultiMethodExecutorchProgram,
     MultiMethodExirExportedProgram,
+    to_edge,
 )
 from executorch.exir.tracer import ExirDynamoConfig
 from torch._export import (  # lots of people are doing from exir import CallSpec, ExportGraphSignature, ExportedProgram which seems wrong
@@ -45,6 +48,9 @@
     "ExecutorchProgram",
     "ExportGraphSignature",
     "_to_edge",
+    "to_edge",
+    "EdgeProgramManager",
+    "ExecutorchProgramManager",
     "edge_to_executorch_passes",
     "MultiMethodExirExportedProgram",
     "MultiMethodExecutorchProgram",
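
The change above publishes the new AoT entry points, to_edge, EdgeProgramManager, and ExecutorchProgramManager, from the top-level exir package. A minimal sketch of how these names are meant to chain together, assuming to_edge accepts a program captured by torch.export and that the edge manager exposes a to_executorch() step (both are assumptions; only the exported names appear in this diff):

import torch
from torch.export import export

from executorch.exir import to_edge


class AddMul(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y * y


example_inputs = (torch.rand(2, 2), torch.rand(2, 2))

# Capture with torch.export, then enter the Edge dialect through the new API.
# to_edge returning an EdgeProgramManager is an assumption based on the
# symbols exported above.
edge_manager = to_edge(export(AddMul(), example_inputs))

# Lower the Edge-dialect program to an ExecutorchProgramManager; the method
# name to_executorch() is likewise an assumption.
executorch_manager = edge_manager.to_executorch()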
59 changes: 1 addition & 58 deletions exir/backend/backend_api.py

@@ -8,10 +8,9 @@
 import logging
 from contextlib import contextmanager
 from functools import singledispatch
-from typing import Dict, Generator, List, Type, Union
+from typing import Generator, List, Type
 
 import torch
-from executorch.exir import MultiMethodExirExportedProgram
 
 from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
 from executorch.exir.backend.compile_spec_schema import CompileSpec
@@ -319,59 +318,3 @@ def to_backend(
         copy.deepcopy(edge_program.equality_constraints),
         copy.deepcopy(edge_program.module_call_graph),
     )
-
-
-def to_backend_multiple(
-    multi_method_program: MultiMethodExirExportedProgram,
-    partitioner: Union[Dict[str, Type[TPartitioner]], Type[TPartitioner]],
-) -> MultiMethodExirExportedProgram:
-    """
-    Returns a semantically equivalent program to the one given as input (represented
-    as a graph module in Edge dialect), but with portions of each method in the
-    program targeted for delegation as determined by the partitioner.
-
-    Args:
-        multi_method_program: A multi-method exported program in Edge dialect.
-
-        partitioner: The partitioner can either be a Partitioner subclass, or a
-            dictionary mapping method names to a Partitioner subclass. If it is a
-            Partitioner subclass, all methods in the given multi-method exported
-            program will be lowered using the given partitioner. If it is a
-            dictionary, only the method names specified in the dictionary will be
-            lowered with the given partitioner.
-
-            The Partitioner subclass is in charge of tagging portions of the
-            input program for delegation. A valid partitioner must have
-            partition_tags: Dict[str, DelegationSpec], where each key is a tag
-            name and the nodes with the same tag will be fused into one subgraph
-            and delegated to the backend specified in the delegation spec.
-
-    Returns:
-        MultiMethodExirExportedProgram: The input program, with some portions
-        targeted for delegation in each method of the program.
-    """
-    if not (isinstance(partitioner, dict) or issubclass(partitioner, Partitioner)):
-        raise TypeError(
-            "partitioner should either be a dictionary of method names to "
-            + "partitioner subclasses, or a partitioner subclass."
-        )
-
-    method_name_to_delegated_program = {}
-    for method_name, prog in multi_method_program.methods().items():
-        if isinstance(partitioner, dict):
-            if method_name in partitioner:
-                method_name_to_delegated_program[method_name] = prog
-                method_name_to_delegated_program[
-                    method_name
-                ].exported_program = to_backend(
-                    prog.exported_program, partitioner[method_name]
-                )
-            else:
-                method_name_to_delegated_program[method_name] = prog
-        else:
-            method_name_to_delegated_program[method_name] = prog
-            method_name_to_delegated_program[method_name].exported_program = to_backend(
-                prog.exported_program, partitioner
-            )
-
-    return MultiMethodExirExportedProgram(method_name_to_delegated_program)
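
With to_backend_multiple removed, the per-method loop it ran can be reproduced by calling the remaining to_backend once per method. A minimal sketch, assuming a plain dict of method names to Edge-dialect exported programs; the helper lower_methods is hypothetical, not part of this PR:

from executorch.exir.backend.backend_api import to_backend


def lower_methods(method_to_exported_program, partitioner):
    # Hypothetical stand-in for the removed function: apply the same
    # partitioner to every method, one to_backend call each, mirroring
    # the non-dict branch of to_backend_multiple above.
    return {
        name: to_backend(prog, partitioner)
        for name, prog in method_to_exported_program.items()
    }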
142 changes: 1 addition & 141 deletions exir/backend/test/test_backends.py

@@ -10,13 +10,7 @@
 
 import executorch.exir as exir
 import torch
-import torch.fx as fx
-from executorch.exir import multi_method_program_to_executorch
-from executorch.exir.backend.backend_api import (
-    LoweredBackendModule,
-    to_backend,
-    to_backend_multiple,
-)
+from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.partitioner import (
     DelegationSpec,
@@ -1235,137 +1229,3 @@ def forward(self, x: List[torch.Tensor]):
 
         gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()
         gm(*inputs)
-
-    def test_lower_multiple(self) -> None:
-        class MultipleMethodModule(torch.nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-
-            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-                return x + y * y
-
-            def method1(self, x: torch.Tensor) -> torch.Tensor:
-                return x + x - x
-
-            def method2(
-                self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor
-            ) -> torch.Tensor:
-                return x + y - z
-
-        module = MultipleMethodModule()
-        method_name_to_args = {
-            "forward": (torch.rand(2, 2), torch.rand(2, 2)),
-            "method1": (torch.rand(2, 2),),
-            "method2": (torch.rand(2, 2), torch.rand(2, 2), torch.rand(2, 2)),
-        }
-
-        multi_method_prog = exir.capture_multiple(
-            module, method_name_to_args, exir.CaptureConfig()
-        ).to_edge()
-
-        lowered_multi_method_prog = to_backend_multiple(
-            multi_method_prog, AddMulPartitionerDemo
-        )
-
-        for method_name, args in method_name_to_args.items():
-            exported_prog = lowered_multi_method_prog.find_method(method_name)
-            self.assertIsNotNone(exported_prog)
-            exported_gm = exported_prog.exported_program.graph_module
-            self.assertIsInstance(exported_gm, fx.GraphModule)
-
-            eager_method = getattr(module, method_name)
-            eager_results = eager_method(*args)
-            exported_results = exported_gm(*args)
-            self.assertTrue(torch.allclose(eager_results, exported_results[0]))
-
-            add_nodes = [
-                node
-                for node in exported_gm.graph.nodes
-                if node.op == "call_function"
-                and node.target == exir_ops.edge.aten.add.Tensor
-            ]
-            self.assertEqual(len(add_nodes), 0)
-
-            lowered_submods = get_lowered_submodules(exported_gm)
-            self.assertEqual(len(lowered_submods), 1)
-
-        _ = multi_method_program_to_executorch(lowered_multi_method_prog)
-
-    def test_lower_multiple_selective(self) -> None:
-        class MultipleMethodModule(torch.nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-
-            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-                return x + y * y
-
-            def method1(self, x: torch.Tensor) -> torch.Tensor:
-                return x + x - x
-
-            def method2(
-                self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor
-            ) -> torch.Tensor:
-                return x + y - z
-
-        module = MultipleMethodModule()
-        method_name_to_args = {
-            "forward": (torch.rand(2, 2), torch.rand(2, 2)),
-            "method1": (torch.rand(2, 2),),
-            "method2": (torch.rand(2, 2), torch.rand(2, 2), torch.rand(2, 2)),
-        }
-
-        multi_method_prog = exir.capture_multiple(
-            module, method_name_to_args, exir.CaptureConfig()
-        ).to_edge()
-
-        method_name_to_partitioners = {
-            "forward": AddMulPartitionerDemo,
-            "method1": AddMulPartitionerDemo,
-        }
-        lowered_multi_method_prog = to_backend_multiple(
-            multi_method_prog, method_name_to_partitioners
-        )
-
-        for method_name, args in method_name_to_args.items():
-            if method_name == "method2":
-                break
-
-            exported_prog = lowered_multi_method_prog.find_method(method_name)
-            self.assertIsNotNone(exported_prog)
-            exported_gm = exported_prog.exported_program.graph_module
-            self.assertIsInstance(exported_gm, fx.GraphModule)
-
-            eager_method = getattr(module, method_name)
-            eager_results = eager_method(*args)
-            exported_results = exported_gm(*args)
-            self.assertTrue(torch.allclose(eager_results, exported_results[0]))
-
-            add_nodes = [
-                node
-                for node in exported_gm.graph.nodes
-                if node.op == "call_function"
-                and node.target == exir_ops.edge.aten.add.Tensor
-            ]
-            self.assertEqual(len(add_nodes), 0)
-
-            lowered_submods = get_lowered_submodules(exported_gm)
-            self.assertEqual(len(lowered_submods), 1)
-
-        # Check that method2 had nothing lowered
-        method2_prog = lowered_multi_method_prog.find_method("method2")
-        self.assertIsNotNone(method2_prog)
-        method2_gm = method2_prog.exported_program.graph_module
-        self.assertIsInstance(method2_gm, fx.GraphModule)
-        add_nodes = [
-            node
-            for node in method2_gm.graph.nodes
-            if node.op == "call_function"
-            and node.target == exir_ops.edge.aten.add.Tensor
-        ]
-        self.assertEqual(len(add_nodes), 1)
-
-        lowered_submods = get_lowered_submodules(method2_gm)
-        self.assertEqual(len(lowered_submods), 0)
-
-        # Check we can export to executorch properly
-        _ = multi_method_program_to_executorch(lowered_multi_method_prog)
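
The deleted tests verified delegation by counting the edge-dialect add nodes left in the top-level graph and the lowered submodules produced by partitioning. That check pattern is independent of the multi-method API; a condensed sketch, where the import paths and the helper name are assumptions based on the identifiers the tests use:

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.lowered_backend_module import get_lowered_submodules


def assert_add_fully_delegated(graph_module) -> None:
    # Hypothetical helper factoring out the deleted tests' assertions:
    # after lowering, no edge-dialect add may remain at the top level,
    # and exactly one lowered submodule should hold the fused subgraph.
    add_nodes = [
        node
        for node in graph_module.graph.nodes
        if node.op == "call_function"
        and node.target == exir_ops.edge.aten.add.Tensor
    ]
    assert not add_nodes
    assert len(get_lowered_submodules(graph_module)) == 1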