
Commit d24c5fd

JacobSzwejbka authored and facebook-github-bot committed
Add AoT apis (#345)
Summary: Added the long-term AoT API for executorch. It consists of 3 stages based on the dialect the program is in: ATen, Edge, and Executorch. The API is unified across multiple entry points and a single entry point. Differential Revision: D49163694
1 parent 7f395fd commit d24c5fd
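
Below, a minimal sketch of how the three stages compose under the new API, assuming the to_edge / EdgeProgramManager / ExecutorchProgramManager entry points added in this diff and torch.export for stage one; the Add module and its inputs are hypothetical examples, not part of the commit.

import torch
from executorch.exir import to_edge

class Add(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y

# Stage 1: ATen dialect, produced by torch.export.
aten_program = torch.export.export(Add(), (torch.ones(2), torch.ones(2)))

# Stage 2: Edge dialect. Per the summary, to_edge should also accept
# multiple entry points (a dict of method names to ExportedPrograms).
edge_manager = to_edge(aten_program)  # EdgeProgramManager

# Stage 3: Executorch dialect, ready for serialization.
et_manager = edge_manager.to_executorch()  # ExecutorchProgramManager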

File tree

10 files changed: 703 additions, 337 deletions

10 files changed

+703
-337
lines changed

exir/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -18,11 +18,14 @@
 from executorch.exir.program import (
     _to_edge,
     edge_to_executorch_passes,
+    EdgeProgramManager,
     ExecutorchProgram,
+    ExecutorchProgramManager,
     ExirExportedProgram,
     multi_method_program_to_executorch,
     MultiMethodExecutorchProgram,
     MultiMethodExirExportedProgram,
+    to_edge,
 )
 from executorch.exir.tracer import ExirDynamoConfig
 from torch._export import (  # lots of people are doing from exir import CallSpec, ExportGraphSignature, ExportedProgram which seems wrong
@@ -45,6 +48,9 @@
     "ExecutorchProgram",
     "ExportGraphSignature",
     "_to_edge",
+    "to_edge",
+    "EdgeProgramManager",
+    "ExecutorchProgramManager",
     "edge_to_executorch_passes",
     "MultiMethodExirExportedProgram",
     "MultiMethodExecutorchProgram",

exir/backend/backend_api.py

Lines changed: 1 addition & 58 deletions
@@ -8,10 +8,9 @@
 import logging
 from contextlib import contextmanager
 from functools import singledispatch
-from typing import Dict, Generator, List, Type, Union
+from typing import Generator, List, Type

 import torch
-from executorch.exir import MultiMethodExirExportedProgram

 from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
 from executorch.exir.backend.compile_spec_schema import CompileSpec
@@ -320,59 +319,3 @@ def to_backend(
         copy.deepcopy(edge_program.equality_constraints),
         copy.deepcopy(edge_program.module_call_graph),
     )
-
-
-def to_backend_multiple(
-    multi_method_program: MultiMethodExirExportedProgram,
-    partitioner: Union[Dict[str, Type[TPartitioner]], Type[TPartitioner]],
-) -> MultiMethodExirExportedProgram:
-    """
-    Returns a semantically-equivalent program to the one given as input (represented
-    as a graph module in Edge dialect), but with portions of each method in the
-    program targeted for delegation as determined by the partitioner.
-
-    Args:
-        multi_method_program: A multiple-method exported program in Edge dialect.
-
-        partitioner: The partitioner can either be a Partitioner subclass, or a
-            dictionary mapping method names to a Partitioner subclass. If it is a
-            Partitioner subclass, all methods in the given multi-method exported
-            program will be lowered using the given partitioner. If it is a
-            dictionary, only the method names specified in the dictionary will be
-            lowered with the given partitioner.
-
-            The Partitioner subclass is in charge of tagging portions of the
-            input program for delegation. A valid partitioner must have
-            partition_tags: Dict[str, DelegationSpec], where each key is a tag
-            name and the nodes with the same tag will be fused as one subgraph
-            and delegated to the backend specified in the delegation spec.
-
-    Returns:
-        MultiMethodExirExportedProgram: The input program, with some portions
-            targeted for delegation in each method of the program.
-    """
-    if not (isinstance(partitioner, dict) or issubclass(partitioner, Partitioner)):
-        raise TypeError(
-            "partitioner should either be a dictionary of method names to "
-            + "partitioner subclass, or a partitioner subclass."
-        )
-
-    method_name_to_delegated_program = {}
-    for method_name, prog in multi_method_program.methods().items():
-        if isinstance(partitioner, dict):
-            if method_name in partitioner:
-                method_name_to_delegated_program[method_name] = prog
-                method_name_to_delegated_program[
-                    method_name
-                ].exported_program = to_backend(
-                    prog.exported_program, partitioner[method_name]
-                )
-            else:
-                method_name_to_delegated_program[method_name] = prog
-        else:
-            method_name_to_delegated_program[method_name] = prog
-            method_name_to_delegated_program[method_name].exported_program = to_backend(
-                prog.exported_program, partitioner
-            )
-
-    return MultiMethodExirExportedProgram(method_name_to_delegated_program)
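
With to_backend_multiple removed, the same effect can be had by applying the single-program to_backend per method. A hedged sketch follows, reusing the .methods() and .exported_program accessors from the deleted code above; lower_each_method and partitioner_map are hypothetical names, not part of this API.

from executorch.exir.backend.backend_api import to_backend

def lower_each_method(multi_method_program, partitioner_map):
    # Lower only the methods named in partitioner_map; the rest pass
    # through unchanged, mirroring the deleted dict branch above.
    for method_name, prog in multi_method_program.methods().items():
        if method_name in partitioner_map:
            prog.exported_program = to_backend(
                prog.exported_program, partitioner_map[method_name]
            )
    return multi_method_program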

exir/backend/test/test_backends.py

Lines changed: 1 addition & 139 deletions
@@ -11,11 +11,7 @@
 import torch
 import torch.fx as fx
 from executorch.exir import multi_method_program_to_executorch
-from executorch.exir.backend.backend_api import (
-    LoweredBackendModule,
-    to_backend,
-    to_backend_multiple,
-)
+from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.partitioner import (
     DelegationSpec,
@@ -1212,137 +1208,3 @@ def forward(self, x: List[torch.Tensor]):

         gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()
         gm(*inputs)
-
-    def test_lower_multiple(self) -> None:
-        class MultipleMethodModule(torch.nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-
-            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-                return x + y * y
-
-            def method1(self, x: torch.Tensor) -> torch.Tensor:
-                return x + x - x
-
-            def method2(
-                self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor
-            ) -> torch.Tensor:
-                return x + y - z
-
-        module = MultipleMethodModule()
-        method_name_to_args = {
-            "forward": (torch.rand(2, 2), torch.rand(2, 2)),
-            "method1": (torch.rand(2, 2),),
-            "method2": (torch.rand(2, 2), torch.rand(2, 2), torch.rand(2, 2)),
-        }
-
-        multi_method_prog = exir.capture_multiple(
-            module, method_name_to_args, exir.CaptureConfig()
-        ).to_edge()
-
-        lowered_multi_method_prog = to_backend_multiple(
-            multi_method_prog, AddMulPartitionerDemo
-        )
-
-        for method_name, args in method_name_to_args.items():
-            exported_prog = lowered_multi_method_prog.find_method(method_name)
-            self.assertIsNotNone(exported_prog)
-            exported_gm = exported_prog.exported_program.graph_module
-            self.assertIsInstance(exported_gm, fx.GraphModule)
-
-            eager_method = getattr(module, method_name)
-            eager_results = eager_method(*args)
-            exported_results = exported_gm(*args)
-            self.assertTrue(torch.allclose(eager_results, exported_results[0]))
-
-            add_nodes = [
-                node
-                for node in exported_gm.graph.nodes
-                if node.op == "call_function"
-                and node.target == exir_ops.edge.aten.add.Tensor
-            ]
-            self.assertEqual(len(add_nodes), 0)
-
-            lowered_submods = get_lowered_submodules(exported_gm)
-            self.assertEqual(len(lowered_submods), 1)
-
-        _ = multi_method_program_to_executorch(lowered_multi_method_prog)
-
-    def test_lower_multiple_selective(self) -> None:
-        class MultipleMethodModule(torch.nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-
-            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-                return x + y * y
-
-            def method1(self, x: torch.Tensor) -> torch.Tensor:
-                return x + x - x
-
-            def method2(
-                self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor
-            ) -> torch.Tensor:
-                return x + y - z
-
-        module = MultipleMethodModule()
-        method_name_to_args = {
-            "forward": (torch.rand(2, 2), torch.rand(2, 2)),
-            "method1": (torch.rand(2, 2),),
-            "method2": (torch.rand(2, 2), torch.rand(2, 2), torch.rand(2, 2)),
-        }
-
-        multi_method_prog = exir.capture_multiple(
-            module, method_name_to_args, exir.CaptureConfig()
-        ).to_edge()
-
-        method_name_to_partitioners = {
-            "forward": AddMulPartitionerDemo,
-            "method1": AddMulPartitionerDemo,
-        }
-        lowered_multi_method_prog = to_backend_multiple(
-            multi_method_prog, method_name_to_partitioners
-        )
-
-        for method_name, args in method_name_to_args.items():
-            if method_name == "method2":
-                break
-
-            exported_prog = lowered_multi_method_prog.find_method(method_name)
-            self.assertIsNotNone(exported_prog)
-            exported_gm = exported_prog.exported_program.graph_module
-            self.assertIsInstance(exported_gm, fx.GraphModule)
-
-            eager_method = getattr(module, method_name)
-            eager_results = eager_method(*args)
-            exported_results = exported_gm(*args)
-            self.assertTrue(torch.allclose(eager_results, exported_results[0]))
-
-            add_nodes = [
-                node
-                for node in exported_gm.graph.nodes
-                if node.op == "call_function"
-                and node.target == exir_ops.edge.aten.add.Tensor
-            ]
-            self.assertEqual(len(add_nodes), 0)
-
-            lowered_submods = get_lowered_submodules(exported_gm)
-            self.assertEqual(len(lowered_submods), 1)
-
-        # Check that method2 had nothing lowered
-        method2_prog = lowered_multi_method_prog.find_method("method2")
-        self.assertIsNotNone(method2_prog)
-        method2_gm = method2_prog.exported_program.graph_module
-        self.assertIsInstance(method2_gm, fx.GraphModule)
-        add_nodes = [
-            node
-            for node in method2_gm.graph.nodes
-            if node.op == "call_function"
-            and node.target == exir_ops.edge.aten.add.Tensor
-        ]
-        self.assertEqual(len(add_nodes), 1)
-
-        lowered_submods = get_lowered_submodules(method2_gm)
-        self.assertEqual(len(lowered_submods), 0)
-
-        # Check we can export to executorch properly
-        _ = multi_method_program_to_executorch(lowered_multi_method_prog)
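
For reference, the delegation check these deleted tests performed is plain torch.fx graph inspection and works against any GraphModule; a sketch, assuming the exir_ops import the surrounding tests already use, with count_aten_adds as a hypothetical helper name.

from executorch.exir.dialects._ops import ops as exir_ops

def count_aten_adds(gm) -> int:
    # Count Edge-dialect aten.add.Tensor call_function nodes left in
    # the graph; zero means the partitioner delegated all of them.
    return sum(
        1
        for node in gm.graph.nodes
        if node.op == "call_function"
        and node.target == exir_ops.edge.aten.add.Tensor
    )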

exir/backend/test/test_backends_lifted.py

Lines changed: 1 addition & 139 deletions
@@ -11,11 +11,7 @@
 import torch
 import torch.fx as fx
 from executorch.exir import multi_method_program_to_executorch
-from executorch.exir.backend.backend_api import (
-    LoweredBackendModule,
-    to_backend,
-    to_backend_multiple,
-)
+from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.partitioner import (
     DelegationSpec,
@@ -1257,137 +1253,3 @@ def forward(self, x: List[torch.Tensor]):

         gm = exir.capture(ComposedM(), inputs, get_testing_capture_config()).to_edge()
         gm(*inputs)
-
-    def test_lower_multiple(self) -> None:
-        class MultipleMethodModule(torch.nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-
-            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-                return x + y * y
-
-            def method1(self, x: torch.Tensor) -> torch.Tensor:
-                return x + x - x
-
-            def method2(
-                self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor
-            ) -> torch.Tensor:
-                return x + y - z
-
-        module = MultipleMethodModule()
-        method_name_to_args = {
-            "forward": (torch.rand(2, 2), torch.rand(2, 2)),
-            "method1": (torch.rand(2, 2),),
-            "method2": (torch.rand(2, 2), torch.rand(2, 2), torch.rand(2, 2)),
-        }
-
-        multi_method_prog = exir.capture_multiple(
-            module, method_name_to_args, get_testing_capture_config()
-        ).to_edge()
-
-        lowered_multi_method_prog = to_backend_multiple(
-            multi_method_prog, AddMulPartitionerDemo
-        )
-
-        for method_name, args in method_name_to_args.items():
-            exported_prog = lowered_multi_method_prog.find_method(method_name)
-            self.assertIsNotNone(exported_prog)
-            exported_gm = exported_prog.exported_program.graph_module
-            self.assertIsInstance(exported_gm, fx.GraphModule)
-
-            eager_method = getattr(module, method_name)
-            eager_results = eager_method(*args)
-            exported_results = exported_gm(*args)
-            self.assertTrue(torch.allclose(eager_results, exported_results[0]))
-
-            add_nodes = [
-                node
-                for node in exported_gm.graph.nodes
-                if node.op == "call_function"
-                and node.target == exir_ops.edge.aten.add.Tensor
-            ]
-            self.assertEqual(len(add_nodes), 0)
-
-            lowered_submods = get_lowered_submodules(exported_gm)
-            self.assertEqual(len(lowered_submods), 1)
-
-        _ = multi_method_program_to_executorch(lowered_multi_method_prog)
-
-    def test_lower_multiple_selective(self) -> None:
-        class MultipleMethodModule(torch.nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-
-            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-                return x + y * y
-
-            def method1(self, x: torch.Tensor) -> torch.Tensor:
-                return x + x - x
-
-            def method2(
-                self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor
-            ) -> torch.Tensor:
-                return x + y - z
-
-        module = MultipleMethodModule()
-        method_name_to_args = {
-            "forward": (torch.rand(2, 2), torch.rand(2, 2)),
-            "method1": (torch.rand(2, 2),),
-            "method2": (torch.rand(2, 2), torch.rand(2, 2), torch.rand(2, 2)),
-        }
-
-        multi_method_prog = exir.capture_multiple(
-            module, method_name_to_args, get_testing_capture_config()
-        ).to_edge()
-
-        method_name_to_partitioners = {
-            "forward": AddMulPartitionerDemo,
-            "method1": AddMulPartitionerDemo,
-        }
-        lowered_multi_method_prog = to_backend_multiple(
-            multi_method_prog, method_name_to_partitioners
-        )
-
-        for method_name, args in method_name_to_args.items():
-            if method_name == "method2":
-                break
-
-            exported_prog = lowered_multi_method_prog.find_method(method_name)
-            self.assertIsNotNone(exported_prog)
-            exported_gm = exported_prog.exported_program.graph_module
-            self.assertIsInstance(exported_gm, fx.GraphModule)
-
-            eager_method = getattr(module, method_name)
-            eager_results = eager_method(*args)
-            exported_results = exported_gm(*args)
-            self.assertTrue(torch.allclose(eager_results, exported_results[0]))
-
-            add_nodes = [
-                node
-                for node in exported_gm.graph.nodes
-                if node.op == "call_function"
-                and node.target == exir_ops.edge.aten.add.Tensor
-            ]
-            self.assertEqual(len(add_nodes), 0)
-
-            lowered_submods = get_lowered_submodules(exported_gm)
-            self.assertEqual(len(lowered_submods), 1)
-
-        # Check that method2 had nothing lowered
-        method2_prog = lowered_multi_method_prog.find_method("method2")
-        self.assertIsNotNone(method2_prog)
-        method2_gm = method2_prog.exported_program.graph_module
-        self.assertIsInstance(method2_gm, fx.GraphModule)
-        add_nodes = [
-            node
-            for node in method2_gm.graph.nodes
-            if node.op == "call_function"
-            and node.target == exir_ops.edge.aten.add.Tensor
-        ]
-        self.assertEqual(len(add_nodes), 1)
-
-        lowered_submods = get_lowered_submodules(method2_gm)
-        self.assertEqual(len(lowered_submods), 0)
-
-        # Check we can export to executorch properly
-        _ = multi_method_program_to_executorch(lowered_multi_method_prog)
