
Commit 0f840e6

JacobSzwejbka authored and facebook-github-bot committed
Add AoT apis (#345)
Summary: Added the long-term AoT API for executorch. It consists of 3 stages, based on the dialect the program is in: ATen, Edge, and Executorch. The API is unified across multiple entry points and a single entry point.

Reviewed By: mergennachin, kimishpatel

Differential Revision: D49163694
1 parent dea872c commit 0f840e6
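A minimal sketch of the single-entry-point flow through the three stages, based only on the names this diff adds (to_edge, EdgeProgramManager, ExecutorchProgramManager); the torch.export.export capture step and the manager's to_executorch method are assumptions, not confirmed by this diff:

import torch
from executorch.exir import to_edge

class MulAdd(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x * y + y

# Stage 1 -- ATen dialect: capture the eager module as an ExportedProgram.
aten_program = torch.export.export(MulAdd(), (torch.ones(2), torch.ones(2)))

# Stage 2 -- Edge dialect: to_edge (added in this commit) wraps the result
# in an EdgeProgramManager.
edge_manager = to_edge(aten_program)

# Stage 3 -- Executorch dialect: assumed to produce an
# ExecutorchProgramManager, the other manager class this commit exports.
executorch_manager = edge_manager.to_executorch()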

10 files changed, +704 -341 lines changed


exir/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -18,11 +18,14 @@
 from executorch.exir.program import (
     _to_edge,
     edge_to_executorch_passes,
+    EdgeProgramManager,
     ExecutorchProgram,
+    ExecutorchProgramManager,
     ExirExportedProgram,
     multi_method_program_to_executorch,
     MultiMethodExecutorchProgram,
     MultiMethodExirExportedProgram,
+    to_edge,
 )
 from executorch.exir.tracer import ExirDynamoConfig
 from torch._export import (  # lots of people are doing from exir import CallSpec, ExportGraphSignature, ExportedProgram which seems wrong
@@ -45,6 +48,9 @@
     "ExecutorchProgram",
     "ExportGraphSignature",
     "_to_edge",
+    "to_edge",
+    "EdgeProgramManager",
+    "ExecutorchProgramManager",
     "edge_to_executorch_passes",
     "MultiMethodExirExportedProgram",
     "MultiMethodExecutorchProgram",

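The new public names mirror the three stages named in the summary; importing them should be as simple as this sketch, assuming the package is importable as executorch:

from executorch.exir import EdgeProgramManager, ExecutorchProgramManager, to_edge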
exir/backend/backend_api.py

Lines changed: 1 addition & 58 deletions
@@ -8,10 +8,9 @@
 import logging
 from contextlib import contextmanager
 from functools import singledispatch
-from typing import Dict, Generator, List, Type, Union
+from typing import Generator, List, Type
 
 import torch
-from executorch.exir import MultiMethodExirExportedProgram
 
 from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
 from executorch.exir.backend.compile_spec_schema import CompileSpec
@@ -319,59 +318,3 @@ def to_backend(
         copy.deepcopy(edge_program.equality_constraints),
         copy.deepcopy(edge_program.module_call_graph),
     )
-
-
-def to_backend_multiple(
-    multi_method_program: MultiMethodExirExportedProgram,
-    partitioner: Union[Dict[str, Type[TPartitioner]], Type[TPartitioner]],
-) -> MultiMethodExirExportedProgram:
-    """
-    Returns a semantically-equivalent program to the one given as input (represented
-    as a graph module in Edge dialect), but with portions of each method in the
-    program targeted for delegation as determined by the partitioner.
-
-    Args:
-        multi_method_program: A multiple-method exported program in Edge dialect.
-
-        partitioner: The partitioner can either be a Partitioner subclass, or a
-            dictionary mapping method names to a Partitioner subclass. If it is a
-            Partitioner subclass, all methods in the given multi-method exported
-            program will be lowered using the given partitioner. If it is a
-            dictionary, only the method names specified in the dictionary will be
-            lowered with the given partitioner.
-
-            The Partitioner subclass is in charge of tagging portions of the
-            input program for delegation. A valid partitioner must have
-            partition_tags: Dict[str, DelegationSpec], where each key is a tag
-            name; the nodes with the same tag will be fused into one subgraph and
-            delegated to the backend specified in the delegation spec.
-
-    Returns:
-        MultiMethodExirExportedProgram: The input program, with some portions
-        targeted for delegation in each method of the program.
-    """
-    if not (isinstance(partitioner, dict) or issubclass(partitioner, Partitioner)):
-        raise TypeError(
-            "partitioner should either be a dictionary of method names to "
-            + "partitioner subclass, or a partitioner subclass."
-        )
-
-    method_name_to_delegated_program = {}
-    for method_name, prog in multi_method_program.methods().items():
-        if isinstance(partitioner, dict):
-            if method_name in partitioner:
-                method_name_to_delegated_program[method_name] = prog
-                method_name_to_delegated_program[
-                    method_name
-                ].exported_program = to_backend(
-                    prog.exported_program, partitioner[method_name]
-                )
-            else:
-                method_name_to_delegated_program[method_name] = prog
-        else:
-            method_name_to_delegated_program[method_name] = prog
-            method_name_to_delegated_program[method_name].exported_program = to_backend(
-                prog.exported_program, partitioner
-            )
-
-    return MultiMethodExirExportedProgram(method_name_to_delegated_program)
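The deleted helper's dictionary dispatch is easy to approximate with the surviving single-program to_backend if needed; a hedged sketch, where method_name_to_program and MyPartitioner are hypothetical placeholders, not names from this diff:

from executorch.exir.backend.backend_api import to_backend

# Lower only the methods that have a partitioner registered, mirroring the
# dictionary branch of the removed to_backend_multiple.
method_to_partitioner = {"forward": MyPartitioner}  # hypothetical partitioner
for name, prog in method_name_to_program.items():  # hypothetical mapping
    if name in method_to_partitioner:
        prog.exported_program = to_backend(
            prog.exported_program, method_to_partitioner[name]
        )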

exir/backend/test/test_backends.py

Lines changed: 1 addition & 141 deletions
@@ -10,13 +10,7 @@
 
 import executorch.exir as exir
 import torch
-import torch.fx as fx
-from executorch.exir import multi_method_program_to_executorch
-from executorch.exir.backend.backend_api import (
-    LoweredBackendModule,
-    to_backend,
-    to_backend_multiple,
-)
+from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.partitioner import (
     DelegationSpec,
@@ -1235,137 +1229,3 @@ def forward(self, x: List[torch.Tensor]):
 
         gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()
         gm(*inputs)
-
-    def test_lower_multiple(self) -> None:
-        class MultipleMethodModule(torch.nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-
-            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-                return x + y * y
-
-            def method1(self, x: torch.Tensor) -> torch.Tensor:
-                return x + x - x
-
-            def method2(
-                self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor
-            ) -> torch.Tensor:
-                return x + y - z
-
-        module = MultipleMethodModule()
-        method_name_to_args = {
-            "forward": (torch.rand(2, 2), torch.rand(2, 2)),
-            "method1": (torch.rand(2, 2),),
-            "method2": (torch.rand(2, 2), torch.rand(2, 2), torch.rand(2, 2)),
-        }
-
-        multi_method_prog = exir.capture_multiple(
-            module, method_name_to_args, exir.CaptureConfig()
-        ).to_edge()
-
-        lowered_multi_method_prog = to_backend_multiple(
-            multi_method_prog, AddMulPartitionerDemo
-        )
-
-        for method_name, args in method_name_to_args.items():
-            exported_prog = lowered_multi_method_prog.find_method(method_name)
-            self.assertIsNotNone(exported_prog)
-            exported_gm = exported_prog.exported_program.graph_module
-            self.assertIsInstance(exported_gm, fx.GraphModule)
-
-            eager_method = getattr(module, method_name)
-            eager_results = eager_method(*args)
-            exported_results = exported_gm(*args)
-            self.assertTrue(torch.allclose(eager_results, exported_results[0]))
-
-            add_nodes = [
-                node
-                for node in exported_gm.graph.nodes
-                if node.op == "call_function"
-                and node.target == exir_ops.edge.aten.add.Tensor
-            ]
-            self.assertEqual(len(add_nodes), 0)
-
-            lowered_submods = get_lowered_submodules(exported_gm)
-            self.assertEqual(len(lowered_submods), 1)
-
-        _ = multi_method_program_to_executorch(lowered_multi_method_prog)
-
-    def test_lower_multiple_selective(self) -> None:
-        class MultipleMethodModule(torch.nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-
-            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-                return x + y * y
-
-            def method1(self, x: torch.Tensor) -> torch.Tensor:
-                return x + x - x
-
-            def method2(
-                self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor
-            ) -> torch.Tensor:
-                return x + y - z
-
-        module = MultipleMethodModule()
-        method_name_to_args = {
-            "forward": (torch.rand(2, 2), torch.rand(2, 2)),
-            "method1": (torch.rand(2, 2),),
-            "method2": (torch.rand(2, 2), torch.rand(2, 2), torch.rand(2, 2)),
-        }
-
-        multi_method_prog = exir.capture_multiple(
-            module, method_name_to_args, exir.CaptureConfig()
-        ).to_edge()
-
-        method_name_to_partitioners = {
-            "forward": AddMulPartitionerDemo,
-            "method1": AddMulPartitionerDemo,
-        }
-        lowered_multi_method_prog = to_backend_multiple(
-            multi_method_prog, method_name_to_partitioners
-        )
-
-        for method_name, args in method_name_to_args.items():
-            if method_name == "method2":
-                break
-
-            exported_prog = lowered_multi_method_prog.find_method(method_name)
-            self.assertIsNotNone(exported_prog)
-            exported_gm = exported_prog.exported_program.graph_module
-            self.assertIsInstance(exported_gm, fx.GraphModule)
-
-            eager_method = getattr(module, method_name)
-            eager_results = eager_method(*args)
-            exported_results = exported_gm(*args)
-            self.assertTrue(torch.allclose(eager_results, exported_results[0]))
-
-            add_nodes = [
-                node
-                for node in exported_gm.graph.nodes
-                if node.op == "call_function"
-                and node.target == exir_ops.edge.aten.add.Tensor
-            ]
-            self.assertEqual(len(add_nodes), 0)
-
-            lowered_submods = get_lowered_submodules(exported_gm)
-            self.assertEqual(len(lowered_submods), 1)
-
-        # Check that method2 had nothing lowered
-        method2_prog = lowered_multi_method_prog.find_method("method2")
-        self.assertIsNotNone(method2_prog)
-        method2_gm = method2_prog.exported_program.graph_module
-        self.assertIsInstance(method2_gm, fx.GraphModule)
-        add_nodes = [
-            node
-            for node in method2_gm.graph.nodes
-            if node.op == "call_function"
-            and node.target == exir_ops.edge.aten.add.Tensor
-        ]
-        self.assertEqual(len(add_nodes), 1)
-
-        lowered_submods = get_lowered_submodules(method2_gm)
-        self.assertEqual(len(lowered_submods), 0)
-
-        # Check we can export to executorch properly
-        _ = multi_method_program_to_executorch(lowered_multi_method_prog)
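For comparison, the deleted tests' capture_multiple/to_backend_multiple flow would presumably map onto the new managers roughly as follows; the dict form of to_edge and the manager-level to_backend/to_executorch methods are assumptions about the new API, not shown in this diff:

from executorch.exir import to_edge

# aten_forward and aten_method1 stand in for ATen-dialect exported programs,
# one per entry point (hypothetical names).
edge_manager = to_edge({"forward": aten_forward, "method1": aten_method1})
edge_manager = edge_manager.to_backend(AddMulPartitionerDemo)  # assumed method
executorch_manager = edge_manager.to_executorch()  # assumed method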
