Skip to content

Commit 6892240

Browse files
mcremon-metafacebook-github-bot
authored andcommitted
Add util to print out ops and frequency
Summary: As titled. Sample output: ``` test ``` Differential Revision: D56001227
1 parent 28ab306 commit 6892240

File tree

3 files changed

+140
-4
lines changed

3 files changed

+140
-4
lines changed

examples/cadence/aot/compiler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import logging
8-
from typing import Any, Callable
8+
from typing import Any, Callable, Tuple
99

1010
import torch
1111

@@ -48,7 +48,7 @@ def export_to_edge(
4848
inputs: Any,
4949
pt2_quant: bool = False,
5050
dump_graphs: bool = False,
51-
) -> EdgeProgramManager:
51+
) -> Tuple[EdgeProgramManager, ExportedProgram]:
5252
# Export the model into an ExportedProgram.
5353
expo_program = export_program(model, inputs, pt2_quant)
5454

@@ -65,4 +65,4 @@ def export_to_edge(
6565
f"Edge graph:\n{edge_prog_manager.exported_program().graph_module.graph}"
6666
)
6767

68-
return edge_prog_manager
68+
return edge_prog_manager, ExportedProgram

examples/cadence/aot/export_example.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from .meta_registrations import * # noqa
1212

13+
from executorch.examples.cadence.aot.utils import print_ops_info
1314
from torch._export import capture_pre_autograd_graph
1415
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
1516

@@ -47,14 +48,23 @@ def export_model(model, example_inputs):
4748
QuantFusion(patterns)(converted_model)
4849

4950
# Get edge program (note: the name will change to export_to_cadence in future PRs)
50-
edge_prog_manager = export_to_edge(converted_model, example_inputs, pt2_quant=True)
51+
edge_prog_manager, expo_prog = export_to_edge(
52+
converted_model, example_inputs, pt2_quant=True
53+
)
5154

5255
# Run a couple required passes for quant/dequant ops
5356
cadence_prog_manager = edge_prog_manager.transform(
5457
[ReplacePT2QuantWithXtensaQuant(), ReplacePT2DequantWithXtensaDequant()],
5558
check_ir_validity=False,
5659
)
5760

61+
# Print some information to terminal
62+
print_ops_info(
63+
expo_prog.graph_module,
64+
edge_prog_manager.exported_program().graph_module,
65+
cadence_prog_manager.exported_program().graph_module,
66+
)
67+
5868
exec_prog = cadence_prog_manager.to_executorch()
5969

6070
logging.info(

examples/cadence/aot/utils.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,14 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import operator
8+
from typing import Dict
9+
710
import torch
11+
from executorch.exir import memory
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
14+
from tabulate import tabulate
815

916

1017
# Get the output size of a 1D convolution given the input size and parameters
@@ -23,3 +30,122 @@ def get_conv1d_output_size(
2330
lout = (L + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
2431

2532
return torch.Size((in_size[0], out_channels, lout))
33+
34+
35+
# Return the overload packet for the edge op
36+
def get_edge_overload_packet(edge_op: EdgeOpOverload) -> EdgeOpOverloadPacket:
37+
edge_op_namespace, edge_op_name = (
38+
edge_op.namespace,
39+
edge_op._schema.name.split("::")[1],
40+
)
41+
edge_op_overload_packet = getattr(
42+
getattr(exir_ops.edge, edge_op_namespace), edge_op_name
43+
)
44+
return edge_op_overload_packet
45+
46+
47+
# Get the frequency list of ops in a graph module
48+
def get_ops_count(graph_module: torch.fx.GraphModule) -> Dict[str, int]:
49+
freq = {}
50+
# Loop over nodes to count the number of times each op occurs
51+
for node in graph_module.graph.nodes:
52+
if node.op == "call_function":
53+
# Ignore getitem, alloc and view cases, we only want actual operations
54+
if (
55+
node.target == operator.getitem
56+
or node.target.__name__ == "alloc"
57+
or node.target == memory.view
58+
):
59+
continue
60+
# If the op is already present, increment the count
61+
if get_edge_overload_packet(node.target).__name__ in freq:
62+
freq[get_edge_overload_packet(node.target).__name__] += 1
63+
# else, add a new entry
64+
else:
65+
freq[get_edge_overload_packet(node.target).__name__] = 1
66+
return freq
67+
68+
69+
# Print the ops and how many times they occur multiple graph modules:
70+
# from export, from to_edge, and from Jarvis. Print the available
71+
# implementations for each op, and error out if the op is not supported.
72+
def print_ops_info(
73+
export_gm: torch.fx.GraphModule,
74+
to_edge_gm: torch.fx.GraphModule,
75+
jarvis_gm: torch.fx.GraphModule,
76+
):
77+
export_ops_count = get_ops_count(export_gm)
78+
to_edge_ops_count = get_ops_count(to_edge_gm)
79+
jarvis_ops_count = get_ops_count(jarvis_gm)
80+
81+
# De-duplicate the "<op>" and "<op>_copy" ops
82+
keys_to_delete_and_add = []
83+
for k1 in export_ops_count:
84+
for k2 in {**to_edge_ops_count, **jarvis_ops_count}:
85+
if k2.startswith(k1):
86+
keys_to_delete_and_add.append((k1, k2))
87+
break
88+
89+
for k in keys_to_delete_and_add:
90+
export_ops_count[k[1]] = export_ops_count[k[0]]
91+
del export_ops_count[k[0]]
92+
93+
removed_ops = []
94+
# Get the counts of the ops that are removed from the final graph
95+
for k in {**export_ops_count, **to_edge_ops_count}:
96+
if k not in jarvis_ops_count:
97+
removed_ops.append(k)
98+
99+
# Create a dict of ops and their counts to pass to tabulate
100+
ops_count = [
101+
[
102+
op,
103+
jarvis_ops_count[op],
104+
to_edge_ops_count[op] if op in to_edge_ops_count else 0,
105+
export_ops_count[op] if op in export_ops_count else 0,
106+
]
107+
for op in jarvis_ops_count
108+
]
109+
sorted_ops_count = sorted(ops_count, key=lambda x: x[1], reverse=True)
110+
111+
# Create a dict of deleted ops and their counts to pass to tabulate
112+
removed_ops_count = [
113+
[
114+
op,
115+
0,
116+
to_edge_ops_count[op] if op in to_edge_ops_count else 0,
117+
export_ops_count[op] if op in export_ops_count else 0,
118+
]
119+
for op in removed_ops
120+
]
121+
122+
# Print the final ops and their counts in a tabular format
123+
print(
124+
tabulate(
125+
sorted_ops_count,
126+
headers=[
127+
"Final Operators ", # one character longer than the longest op name
128+
"Jarvis (Final) Graph",
129+
"To_edge Graph",
130+
"Export Graph",
131+
"Implementation Available",
132+
],
133+
tablefmt="outline",
134+
)
135+
)
136+
137+
# Print the removed ops and their counts in a tabular format (if any)
138+
if removed_ops != []:
139+
print(
140+
tabulate(
141+
removed_ops_count,
142+
headers=[
143+
"Deleted Operators ", # one character longer than the longest op name
144+
"Jarvis (Final) Graph",
145+
"To_edge Graph",
146+
"Export Graph",
147+
"Implementation Available",
148+
],
149+
tablefmt="outline",
150+
)
151+
)

0 commit comments

Comments
 (0)