
Commit a39ea29

add remove ops to oss and callsites, [cadence][8/X] add reorder ops to oss and callsites, [cadence][9/X] add replace ops to oss and callsites, [cadence][10/X] merge passes with replace remove ops passes and update default pass order...
Differential Revision: D66264166
Pull Request resolved: #6993
1 parent 792ef43 commit a39ea29

15 files changed: +7349, -333 lines

backends/cadence/README.md

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 
 ## Supported DSPs (in progress)
 - HiFi Audio
-- ...
+- Fusion G3
 
 ## Tutorial
 

backends/cadence/aot/TARGETS

Lines changed: 159 additions & 1 deletion
@@ -39,6 +39,7 @@ python_library(
         ":passes",
         ":utils",
         ":ops_registrations",
+        ":replace_ops",
         "//caffe2:torch",
         "//executorch/backends/cadence/aot/quantizer:fusion_pass",
         "//executorch/backends/cadence/aot/quantizer:quantizer",
@@ -74,12 +75,14 @@ python_library(
         ":utils",
         ":fuse_ops",
         ":simplify_ops",
+        ":replace_ops",
+        ":reorder_ops",
+        ":remove_ops",
         "//caffe2:torch",
         "//executorch/exir:pass_base",
         "//executorch/exir/dialects:lib",
         "//executorch/exir/passes:lib",
         "//executorch/exir/passes:spec_prop_pass",
-        "//executorch/backends/transforms:remove_clone_ops"
     ],
 )
 
@@ -180,6 +183,63 @@ python_library(
     ],
 )
 
+python_library(
+    name = "remove_ops",
+    srcs = [
+        "remove_ops.py",
+    ],
+    typing = True,
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/cadence/aot:pass_utils",
+        "//executorch/backends/cadence/aot:simplify_ops",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+        "//executorch/exir/dialects/edge:lib",
+        "//executorch/exir/passes:spec_prop_pass",
+        "//executorch/backends/transforms:remove_clone_ops"
+    ],
+)
+
+python_library(
+    name = "reorder_ops",
+    srcs = [
+        "reorder_ops.py",
+    ],
+    typing = True,
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/cadence/aot:compiler_utils",
+        "//executorch/backends/cadence/aot:pass_utils",
+        "//executorch/backends/cadence/aot:utils",
+        "//executorch/exir:pass_base",
+        "//executorch/exir:tensor",
+        "//executorch/exir/dialects:lib",
+        "//executorch/exir/dialects/edge:lib",
+    ],
+)
+
+python_library(
+    name = "replace_ops",
+    srcs = [
+        "replace_ops.py",
+    ],
+    typing = True,
+    deps = [
+        ":pass_utils",
+        "//caffe2:torch",
+        "//executorch/backends/cadence/aot:compiler_utils",
+        "//executorch/backends/cadence/aot:fuse_ops",
+        "//executorch/backends/cadence/aot:pass_utils",
+        "//executorch/backends/cadence/aot:remove_ops",
+        "//executorch/backends/cadence/aot:utils",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+        "//executorch/exir/dialects/edge:lib",
+        "//executorch/exir/passes:spec_prop_pass",
+    ],
+)
+
 python_unittest(
     name = "test_graph_builder",
     srcs = [
@@ -196,3 +256,101 @@ python_unittest(
         ":ops_registrations"
     ],
 )
+
+python_unittest(
+    name = "test_replace_ops_passes",
+    srcs = [
+        "tests/test_replace_ops_passes.py",
+    ],
+    supports_static_listing = False,
+    typing = True,
+    deps = [
+        "fbsource//third-party/pypi/parameterized:parameterized",
+        ":compiler",
+        ":replace_ops",
+        "//caffe2:torch",
+        "//executorch/backends/cadence/aot:compiler",
+        "//executorch/backends/cadence/aot:graph_builder",
+        "//executorch/backends/cadence/aot:pass_utils",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+        "//executorch/exir/passes:lib",
+    ],
+)
+
+python_unittest(
+    name = "test_fusion_ops_passes",
+    srcs = [
+        "tests/test_fusion_ops_passes.py",
+    ],
+    typing = True,
+    deps = [
+        ":compiler",
+        "//caffe2:torch",
+        "//executorch/backends/cadence/aot:compiler",
+        "//executorch/backends/cadence/aot:fuse_ops",
+        "//executorch/backends/cadence/aot:graph_builder",
+        "//executorch/backends/cadence/aot:ops_registrations",
+        "//executorch/backends/cadence/aot:pass_utils",
+        "//executorch/exir/dialects:lib",
+        "//executorch/exir/dialects/edge:lib",
+    ],
+)
+
+python_unittest(
+    name = "test_remove_ops_passes",
+    srcs = [
+        "tests/test_remove_ops_passes.py",
+    ],
+    supports_static_listing = False,
+    typing = True,
+    deps = [
+        "fbsource//third-party/pypi/parameterized:parameterized",
+        "fbsource//third-party/pypi/pyre-extensions:pyre-extensions",
+        ":compiler",
+        "//caffe2:torch",
+        "//executorch/backends/cadence/aot:compiler",
+        "//executorch/backends/cadence/aot:ops_registrations",
+        "//executorch/backends/cadence/aot:pass_utils",
+        "//executorch/backends/cadence/aot:remove_ops",
+        "//executorch/backends/cadence/aot/quantizer:quantizer",
+        "//executorch/exir/dialects:lib",
+    ],
+)
+
+python_unittest(
+    name = "test_simplify_ops_passes",
+    srcs = [
+        "tests/test_simplify_ops_passes.py",
+    ],
+    supports_static_listing = False,
+    typing = True,
+    deps = [
+        "fbsource//third-party/pypi/parameterized:parameterized",
+        "//caffe2:torch",
+        "//executorch/backends/cadence/aot:compiler",
+        "//executorch/backends/cadence/aot:ops_registrations",
+        "//executorch/backends/cadence/aot:pass_utils",
+        "//executorch/backends/cadence/aot:simplify_ops",
+        "//executorch/exir/dialects:lib",
+    ],
+)
+
+python_unittest(
+    name = "test_reorder_ops_passes",
+    srcs = [
+        "tests/test_reorder_ops_passes.py",
+    ],
+    typing = True,
+    deps = [
+        ":compiler",
+        ":pass_utils",
+        "//caffe2:torch",
+        "//executorch/backends/cadence/aot:compiler",
+        "//executorch/backends/cadence/aot:fuse_ops",
+        "//executorch/backends/cadence/aot:ops_registrations",
+        "//executorch/backends/cadence/aot:pass_utils",
+        "//executorch/backends/cadence/aot:reorder_ops",
+        "//executorch/exir/dialects:lib",
+    ],
+)
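
For orientation, here is a small sketch (not part of the diff) of the import paths the three new python_library targets correspond to, inferred from their srcs entries and from the compiler.py import change further down in this commit:

# Sketch only: module paths inferred from the srcs entries above.
from executorch.backends.cadence.aot import remove_ops, reorder_ops, replace_ops

# replace_ops now hosts passes that compiler.py previously imported from
# passes.py, e.g. ReplaceSafeSoftmaxWithSoftmax (see the compiler.py hunk below).
from executorch.backends.cadence.aot.replace_ops import ReplaceSafeSoftmaxWithSoftmax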

backends/cadence/aot/compiler.py

Lines changed: 21 additions & 5 deletions
@@ -12,10 +12,10 @@
 
 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
-
-from executorch.backends.cadence.aot.passes import ReplaceSafeSoftmaxWithSoftmax
 from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
 from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
+
+from executorch.backends.cadence.aot.replace_ops import ReplaceSafeSoftmaxWithSoftmax
 from executorch.backends.cadence.aot.utils import model_gm_has_SDPA, model_is_quantized
 from executorch.backends.transforms.decompose_sdpa import (
     DecomposeScaledDotProductAttention,
@@ -194,9 +194,6 @@ def export_to_edge(
     return edge_prog_manager
 
 
-# Export the model and lower it to an EdgeProgramManager (in edge IR), and
-# apply passes specific to Cadence DSP execution. Return both to print the
-# differences.
 def export_to_cadence(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
@@ -216,6 +213,25 @@ def export_to_cadence(
     return cadence_prog_manager
 
 
+def quantize_and_export_to_cadence(
+    model: torch.nn.Module,
+    inputs: tuple[object, ...],
+    dump_graphs: bool = False,
+    opt_level: int = 1,
+) -> EdgeProgramManager:
+    quantized_model = quantize_pt2(model, inputs)
+
+    return export_to_cadence(
+        quantized_model,
+        inputs,
+        opt_level=opt_level,
+        dump_graphs=dump_graphs,
+    )
+
+
+# Export the model and lower it to an EdgeProgramManager (in edge IR), and
+# apply passes specific to Cadence DSP execution. Return both to print the
+# differences.
 def export_to_executorch_gen_etrecord(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
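
A minimal usage sketch of the new quantize_and_export_to_cadence helper; the toy module and example inputs below are hypothetical and only illustrate the call signature added above:

# Hypothetical example, not part of the commit: exercise the new helper on a
# small eager-mode module.
import torch

from executorch.backends.cadence.aot.compiler import quantize_and_export_to_cadence


class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(self.linear(x))


model = TinyModel().eval()
example_inputs = (torch.randn(1, 16),)

# Quantizes with the Cadence quantizer (via quantize_pt2) and then lowers to
# edge IR with the Cadence-specific passes applied at the requested opt_level.
edge_prog_manager = quantize_and_export_to_cadence(
    model,
    example_inputs,
    dump_graphs=False,
    opt_level=1,
)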

backends/cadence/aot/pass_utils.py

Lines changed: 46 additions & 2 deletions
@@ -3,7 +3,7 @@
 # pyre-strict
 
 from dataclasses import dataclass
-from typing import Callable, Optional, Set, Union
+from typing import Callable, List, Optional, Set, Union
 
 import torch
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
@@ -50,7 +50,7 @@ def get_all_available_cadence_passes() -> Set[ExportPass]:
     return set(ALL_CADENCE_PASSES.keys())
 
 
-# Create a new filter to filter out relevant passes from all Jarvis passes.
+# Create a new filter to filter out relevant passes from all passes.
 def create_cadence_pass_filter(
     opt_level: int, debug: bool = False
 ) -> Callable[[ExportPass], bool]:
@@ -98,3 +98,47 @@ def count_node(graph_module: torch.fx.GraphModule, target: torch.fx.node.Target)
         if node.op == "call_function" and node.target == target:
             total += 1
     return total
+
+
+# Testing utils
+# Return the compute/function nodes in the graph
+def get_compute_nodes_in_gm(graph_module: torch.fx.GraphModule) -> List[torch.fx.Node]:
+    nodes = []
+    for x in graph_module.graph.nodes:
+        if x.op == "call_function":
+            if isinstance(x.target, torch._ops.OpOverload):
+                nodes.append(x.target.overloadpacket)
+            elif isinstance(x.target, EdgeOpOverload):
+                nodes.append(get_edge_overload_packet(x.target))
+    return nodes
+
+
+# Return true if there is no edge from a node with target pred_target to a
+# node with target succ_target in the graph.
+def nodes_not_connected_in_gm(
+    graph_module: torch.fx.GraphModule,
+    pred_target: torch.fx.Node,
+    succ_target: torch.fx.Node,
+) -> bool:
+    for node in graph_module.graph.nodes:
+        if node.target != pred_target:
+            continue
+        for user in node.users:
+            if user.target == succ_target:
+                return False
+    return True
+
+
+# Returns true if there is no instance of a node with target succ_target
+# positioned immediately after a node with target pred_target in the graph
+def nodes_not_adjacent_in_gm(
+    graph_module: torch.fx.GraphModule,
+    pred_target: torch.fx.Node,
+    succ_target: torch.fx.Node,
+) -> bool:
+    for node in graph_module.graph.nodes:
+        if node.target != pred_target:
+            continue
+        if node.next.target == succ_target:
+            return False
+    return True
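
A short sketch of how a pass unit test might use these new helpers; the operator targets chosen below (permute_copy, add) are placeholders for illustration, not taken from the commit:

# Hypothetical test snippet built on the helpers above.
import torch

from executorch.backends.cadence.aot.pass_utils import (
    count_node,
    get_compute_nodes_in_gm,
    nodes_not_adjacent_in_gm,
    nodes_not_connected_in_gm,
)


def check_optimized_graph(graph_module: torch.fx.GraphModule) -> None:
    # Inspect which call_function targets (as overload packets) remain.
    compute_nodes = get_compute_nodes_in_gm(graph_module)
    print(f"compute nodes: {compute_nodes}")

    # Count surviving occurrences of a specific target.
    num_permutes = count_node(graph_module, torch.ops.aten.permute_copy.default)
    print(f"permute_copy nodes remaining: {num_permutes}")

    # Check structural properties after reordering: no permute feeding an add,
    # and no permute placed immediately before an add in node order.
    assert nodes_not_connected_in_gm(
        graph_module, torch.ops.aten.permute_copy.default, torch.ops.aten.add.Tensor
    )
    assert nodes_not_adjacent_in_gm(
        graph_module, torch.ops.aten.permute_copy.default, torch.ops.aten.add.Tensor
    )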
