Skip to content

Commit b1edc3d

Browse files
mcremon-metafacebook-github-bot
authored andcommitted
Add util to print out ops and frequency (#2983)
Summary: Pull Request resolved: #2983 As titled. Reviewed By: cccclai Differential Revision: D56001227 fbshipit-source-id: cefef12662e03171136f03138fb814d61a28a0f3
1 parent 5b7c4ba commit b1edc3d

File tree

3 files changed

+139
-4
lines changed

3 files changed

+139
-4
lines changed

examples/cadence/aot/compiler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import logging
8-
from typing import Any, Callable
8+
from typing import Any, Callable, Tuple
99

1010
import torch
1111

@@ -48,7 +48,7 @@ def export_to_edge(
4848
inputs: Any,
4949
pt2_quant: bool = False,
5050
dump_graphs: bool = False,
51-
) -> EdgeProgramManager:
51+
) -> Tuple[EdgeProgramManager, ExportedProgram]:
5252
# Export the model into an ExportedProgram.
5353
expo_program = export_program(model, inputs, pt2_quant)
5454

@@ -65,4 +65,4 @@ def export_to_edge(
6565
f"Edge graph:\n{edge_prog_manager.exported_program().graph_module.graph}"
6666
)
6767

68-
return edge_prog_manager
68+
return edge_prog_manager, expo_program

examples/cadence/aot/export_example.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
ReplacePT2DequantWithCadenceDequant,
2323
ReplacePT2QuantWithCadenceQuant,
2424
)
25+
from .utils import print_ops_info
2526

2627

2728
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -47,7 +48,9 @@ def export_model(model, example_inputs):
4748
QuantFusion(patterns)(converted_model)
4849

4950
# Get edge program (note: the name will change to export_to_cadence in future PRs)
50-
edge_prog_manager = export_to_edge(converted_model, example_inputs, pt2_quant=True)
51+
edge_prog_manager, expo_prog = export_to_edge(
52+
converted_model, example_inputs, pt2_quant=True
53+
)
5154

5255
# Run a couple required passes for quant/dequant ops
5356
cadence_prog_manager = edge_prog_manager.transform(
@@ -61,5 +64,12 @@ def export_model(model, example_inputs):
6164
f"Final exported graph module:\n{exec_prog.exported_program().graph_module}"
6265
)
6366

67+
# Print some information to terminal
68+
print_ops_info(
69+
expo_prog.graph_module,
70+
edge_prog_manager.exported_program().graph_module,
71+
cadence_prog_manager.exported_program().graph_module,
72+
)
73+
6474
# Save the program as CadenceDemoModel.pte
6575
save_pte_program(exec_prog, "CadenceDemoModel")

examples/cadence/aot/utils.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,15 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import logging
8+
import operator
9+
from typing import Dict
10+
711
import torch
12+
from executorch.exir import memory
13+
from executorch.exir.dialects._ops import ops as exir_ops
14+
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
15+
from tabulate import tabulate
816

917

1018
# Get the output size of a 1D convolution given the input size and parameters
@@ -23,3 +31,120 @@ def get_conv1d_output_size(
2331
lout = (L + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
2432

2533
return torch.Size((in_size[0], out_channels, lout))
34+
35+
36+
# Return the overload packet for the edge op
37+
def get_edge_overload_packet(edge_op: EdgeOpOverload) -> EdgeOpOverloadPacket:
38+
edge_op_namespace, edge_op_name = (
39+
edge_op.namespace,
40+
edge_op._schema.name.split("::")[1],
41+
)
42+
edge_op_overload_packet = getattr(
43+
getattr(exir_ops.edge, edge_op_namespace), edge_op_name
44+
)
45+
return edge_op_overload_packet
46+
47+
48+
# Get the frequency list of ops in a graph module
49+
def get_ops_count(graph_module: torch.fx.GraphModule) -> Dict[str, int]:
50+
freq = {}
51+
# Loop over nodes to count the number of times each op occurs
52+
for node in graph_module.graph.nodes:
53+
if node.op == "call_function":
54+
# Ignore getitem, alloc and view cases, we only want actual operations
55+
if (
56+
node.target == operator.getitem
57+
or node.target.__name__ == "alloc"
58+
or node.target == memory.view
59+
):
60+
continue
61+
# If the op is already present, increment the count
62+
if get_edge_overload_packet(node.target).__name__ in freq:
63+
freq[get_edge_overload_packet(node.target).__name__] += 1
64+
# else, add a new entry
65+
else:
66+
freq[get_edge_overload_packet(node.target).__name__] = 1
67+
return freq
68+
69+
70+
# Print the ops and how many times they occur multiple graph modules:
71+
# from export, from to_edge, and from Jarvis. Print the available
72+
# implementations for each op, and error out if the op is not supported.
73+
def print_ops_info(
74+
export_gm: torch.fx.GraphModule,
75+
to_edge_gm: torch.fx.GraphModule,
76+
jarvis_gm: torch.fx.GraphModule,
77+
):
78+
export_ops_count = get_ops_count(export_gm)
79+
to_edge_ops_count = get_ops_count(to_edge_gm)
80+
jarvis_ops_count = get_ops_count(jarvis_gm)
81+
82+
# De-duplicate the "<op>" and "<op>_copy" ops
83+
keys_to_delete_and_add = []
84+
for k1 in export_ops_count:
85+
for k2 in {**to_edge_ops_count, **jarvis_ops_count}:
86+
if k2.startswith(k1):
87+
keys_to_delete_and_add.append((k1, k2))
88+
break
89+
90+
for k in keys_to_delete_and_add:
91+
export_ops_count[k[1]] = export_ops_count[k[0]]
92+
del export_ops_count[k[0]]
93+
94+
removed_ops = []
95+
# Get the counts of the ops that are removed from the final graph
96+
for k in {**export_ops_count, **to_edge_ops_count}:
97+
if k not in jarvis_ops_count:
98+
removed_ops.append(k)
99+
100+
# Create a dict of ops and their counts to pass to tabulate
101+
ops_count = [
102+
[
103+
op,
104+
jarvis_ops_count[op],
105+
to_edge_ops_count[op] if op in to_edge_ops_count else 0,
106+
export_ops_count[op] if op in export_ops_count else 0,
107+
]
108+
for op in jarvis_ops_count
109+
]
110+
sorted_ops_count = sorted(ops_count, key=lambda x: x[1], reverse=True)
111+
112+
# Create a dict of deleted ops and their counts to pass to tabulate
113+
removed_ops_count = [
114+
[
115+
op,
116+
0,
117+
to_edge_ops_count[op] if op in to_edge_ops_count else 0,
118+
export_ops_count[op] if op in export_ops_count else 0,
119+
]
120+
for op in removed_ops
121+
]
122+
123+
# Print the final ops and their counts in a tabular format
124+
logging.info(
125+
tabulate(
126+
sorted_ops_count,
127+
headers=[
128+
"Final Operators ", # one character longer than the longest op name
129+
"Jarvis (Final) Graph",
130+
"To_edge Graph",
131+
"Export Graph",
132+
],
133+
tablefmt="outline",
134+
)
135+
)
136+
137+
# Print the removed ops and their counts in a tabular format (if any)
138+
if removed_ops != []:
139+
logging.info(
140+
tabulate(
141+
removed_ops_count,
142+
headers=[
143+
"Deleted Operators ", # one character longer than the longest op name
144+
"Jarvis (Final) Graph",
145+
"To_edge Graph",
146+
"Export Graph",
147+
],
148+
tablefmt="outline",
149+
)
150+
)

0 commit comments

Comments
 (0)