
Commit 09e215d

mcremon-meta authored and facebook-github-bot committed
Clean up lowering APIs
Summary: Changes some types in `export_to_edge`, adds an `export_to_cadence` function, and refactors some calls. Mostly preemptive work before adding a lot of AoT passes.

Differential Revision: D57579738
1 parent a707550 commit 09e215d
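For orientation, a minimal call-site sketch of the API change this commit describes. The function names come from the compiler.py diff below; the toy module and inputs are hypothetical.

import torch

from executorch.backends.cadence.aot.compiler import export_to_cadence, export_to_edge

# Hypothetical toy model, only to illustrate the call shapes.
model = torch.nn.Linear(4, 2)
example_inputs = (torch.randn(1, 4),)

# Before this commit: edge_prog_manager, expo_prog = export_to_edge(model, example_inputs)
# After this commit: export_to_edge returns only the EdgeProgramManager.
edge_prog_manager = export_to_edge(model, example_inputs)

# New in this commit: export_to_cadence also applies the Cadence quant/dequant
# replacement passes to the edge program.
cadence_prog_manager = export_to_cadence(model, example_inputs)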

File tree

4 files changed: +71 -35 lines changed


backends/cadence/aot/TARGETS

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+
+oncall("odai_jarvis")
+
+python_library(
+    name = "compiler",
+    srcs = [
+        "compiler.py",
+    ],
+    deps = [
+        ":passes",
+        "//caffe2:torch",
+        "//executorch/exir:lib",
+    ],
+)
+
+python_library(
+    name = "passes",
+    srcs = [
+        "passes.py",
+    ],
+    deps = [
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+    ],
+)
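In Python terms, the two Buck targets above wrap the modules changed in this commit; a minimal sketch of what each exposes, based only on the diffs below:

# ":compiler" covers backends/cadence/aot/compiler.py
from executorch.backends.cadence.aot.compiler import (
    export_program,
    export_to_cadence,
    export_to_edge,
)

# ":passes" covers backends/cadence/aot/passes.py
from executorch.backends.cadence.aot.passes import (
    ReplacePT2DequantWithCadenceDequant,
    ReplacePT2QuantWithCadenceQuant,
)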

backends/cadence/aot/compiler.py

Lines changed: 33 additions & 6 deletions
@@ -9,15 +9,20 @@
 
 import torch
 
+from executorch.backends.cadence.aot.passes import (
+    ReplacePT2DequantWithCadenceDequant,
+    ReplacePT2QuantWithCadenceQuant,
+)
 from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
 
 from torch.export import export
 from torch.export.exported_program import ExportedProgram
 
 
+# Export the model and lower it to an ExportedProgram (in aten IR)
 def export_program(
     model: torch.nn.Module,
-    inputs: Any,
+    inputs: Tuple[Any, ...],
 ) -> ExportedProgram:
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
 
@@ -37,12 +42,14 @@ def export_program(
     return export(model, inputs)
 
 
-# Export the model and lower it it edge IR.
+# Export the model and lower it to an EdgeProgramManager (in edge IR).
 def export_to_edge(
     model: torch.nn.Module,
-    inputs: Any,
+    inputs: Tuple[Any, ...],
     dump_graphs: bool = False,
-) -> Tuple[EdgeProgramManager, ExportedProgram]:
+) -> EdgeProgramManager:
+    assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
+
     # Export the model into an ExportedProgram.
     expo_program = export_program(model, inputs)
 
@@ -51,12 +58,32 @@ def export_to_edge(
 
     # Call to_edge to convert the graph to edge IR.
     edge_prog_manager = to_edge(
-        expo_program, compile_config=EdgeCompileConfig(_check_ir_validity=False)
+        expo_program,
+        compile_config=EdgeCompileConfig(
+            _check_ir_validity=False, _skip_dim_order=True
+        ),
     )
 
     if dump_graphs:
         logging.info(
             f"Edge graph:\n{edge_prog_manager.exported_program().graph_module.graph}"
         )
 
-    return edge_prog_manager, expo_program
+    return edge_prog_manager
+
+
+# Export the model and lower it to an EdgeProgramManager (in edge IR), and
+# apply passes specific to Cadence DSP execution.
+def export_to_cadence(
+    model: torch.nn.Module,
+    inputs: Tuple[Any, ...],
+    dump_graphs: bool = False,
+) -> EdgeProgramManager:
+    edge_program_manager = export_to_edge(model, inputs)
+
+    # Run a couple required passes for quant/dequant ops
+    cadence_program_manager = edge_program_manager.transform(
+        [ReplacePT2QuantWithCadenceQuant(), ReplacePT2DequantWithCadenceDequant()]
+    )
+
+    return cadence_program_manager
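A hedged usage sketch of the new entry point defined above, for a caller that wants a Cadence-lowered ExecuTorch program; the toy module is illustrative, and the exported_program() / to_executorch() calls mirror export_example.py below. A real flow would quantize the model first, as that example does.

import torch

from executorch.backends.cadence.aot.compiler import export_to_cadence

# Illustrative model and inputs only.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
example_inputs = (torch.randn(1, 8),)

# Lower to edge IR with the Cadence quant/dequant replacement passes applied.
cadence_prog_manager = export_to_cadence(model, example_inputs)

# Inspect the transformed edge graph...
print(cadence_prog_manager.exported_program().graph_module.graph)

# ...then emit the final ExecuTorch program.
exec_prog = cadence_prog_manager.to_executorch()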

backends/cadence/aot/export_example.py

Lines changed: 5 additions & 12 deletions
@@ -13,11 +13,7 @@
 import os
 from typing import Any, Tuple
 
-from executorch.backends.cadence.aot.compiler import export_to_edge
-from executorch.backends.cadence.aot.passes import (
-    ReplacePT2DequantWithCadenceDequant,
-    ReplacePT2QuantWithCadenceQuant,
-)
+from executorch.backends.cadence.aot.compiler import export_to_cadence, export_to_edge
 from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
 from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
 from executorch.exir import ExecutorchProgramManager
@@ -68,13 +64,11 @@ def export_model(
     patterns = [q.pattern for q in quantizer.quantizers]
     QuantFusion(patterns)(converted_model)
 
-    # Get edge program (note: the name will change to export_to_cadence in future PRs)
-    edge_prog_manager, expo_prog = export_to_edge(converted_model, example_inputs)
+    # Get edge program
+    edge_prog_manager = export_to_edge(converted_model, example_inputs)
 
-    # Run a couple required passes for quant/dequant ops
-    cadence_prog_manager = edge_prog_manager.transform(
-        [ReplacePT2QuantWithCadenceQuant(), ReplacePT2DequantWithCadenceDequant()]
-    )
+    # Get edge program after Cadence specific passes
+    cadence_prog_manager = export_to_cadence(converted_model, example_inputs)
 
     exec_prog = cadence_prog_manager.to_executorch()
 
@@ -84,7 +78,6 @@ def export_model(
 
     # Print some information to terminal
     print_ops_info(
-        expo_prog.graph_module,
         edge_prog_manager.exported_program().graph_module,
         cadence_prog_manager.exported_program().graph_module,
     )

backends/cadence/aot/utils.py

Lines changed: 1 addition & 17 deletions
@@ -71,29 +71,15 @@ def get_ops_count(graph_module: torch.fx.GraphModule) -> Dict[str, int]:
 # from export, from to_edge, and from Jarvis. Print the available
 # implementations for each op, and error out if the op is not supported.
 def print_ops_info(
-    export_gm: torch.fx.GraphModule,
     to_edge_gm: torch.fx.GraphModule,
     jarvis_gm: torch.fx.GraphModule,
 ):
-    export_ops_count = get_ops_count(export_gm)
     to_edge_ops_count = get_ops_count(to_edge_gm)
     jarvis_ops_count = get_ops_count(jarvis_gm)
 
-    # De-duplicate the "<op>" and "<op>_copy" ops
-    keys_to_delete_and_add = []
-    for k1 in export_ops_count:
-        for k2 in {**to_edge_ops_count, **jarvis_ops_count}:
-            if k2.startswith(k1):
-                keys_to_delete_and_add.append((k1, k2))
-                break
-
-    for k in keys_to_delete_and_add:
-        export_ops_count[k[1]] = export_ops_count[k[0]]
-        del export_ops_count[k[0]]
-
     removed_ops = []
     # Get the counts of the ops that are removed from the final graph
-    for k in {**export_ops_count, **to_edge_ops_count}:
+    for k in to_edge_ops_count:
         if k not in jarvis_ops_count:
             removed_ops.append(k)
 
@@ -103,7 +89,6 @@ def print_ops_info(
             op,
             jarvis_ops_count[op],
             to_edge_ops_count[op] if op in to_edge_ops_count else 0,
-            export_ops_count[op] if op in export_ops_count else 0,
         ]
         for op in jarvis_ops_count
     ]
@@ -115,7 +100,6 @@
             op,
             0,
             to_edge_ops_count[op] if op in to_edge_ops_count else 0,
-            export_ops_count[op] if op in export_ops_count else 0,
         ]
         for op in removed_ops
     ]
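Continuing the sketch after the compiler.py diff above, the updated print_ops_info call shape looks like this; the import path is assumed from this file's location, and edge_prog_manager / cadence_prog_manager are the managers produced by export_to_edge / export_to_cadence as in export_example.py.

from executorch.backends.cadence.aot.utils import print_ops_info

# The export-time graph argument is gone; only the to_edge and Jarvis (Cadence)
# graph modules are counted and compared now.
print_ops_info(
    edge_prog_manager.exported_program().graph_module,
    cadence_prog_manager.exported_program().graph_module,
)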
