4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
+ import operator
8
+ from typing import Dict
9
+
7
10
import torch
11
+ from executorch .exir import memory
12
+ from executorch .exir .dialects ._ops import ops as exir_ops
13
+ from executorch .exir .dialects .edge ._ops import EdgeOpOverload , EdgeOpOverloadPacket
14
+ from tabulate import tabulate
8
15
9
16
10
17
# Get the output size of a 1D convolution given the input size and parameters
@@ -23,3 +30,122 @@ def get_conv1d_output_size(
23
30
lout = (L + 2 * padding - dilation * (kernel_size - 1 ) - 1 ) // stride + 1
24
31
25
32
return torch .Size ((in_size [0 ], out_channels , lout ))
33
+
34
+
35
+ # Return the overload packet for the edge op
36
+ def get_edge_overload_packet (edge_op : EdgeOpOverload ) -> EdgeOpOverloadPacket :
37
+ edge_op_namespace , edge_op_name = (
38
+ edge_op .namespace ,
39
+ edge_op ._schema .name .split ("::" )[1 ],
40
+ )
41
+ edge_op_overload_packet = getattr (
42
+ getattr (exir_ops .edge , edge_op_namespace ), edge_op_name
43
+ )
44
+ return edge_op_overload_packet
45
+
46
+
47
+ # Get the frequency list of ops in a graph module
48
+ def get_ops_count (graph_module : torch .fx .GraphModule ) -> Dict [str , int ]:
49
+ freq = {}
50
+ # Loop over nodes to count the number of times each op occurs
51
+ for node in graph_module .graph .nodes :
52
+ if node .op == "call_function" :
53
+ # Ignore getitem, alloc and view cases, we only want actual operations
54
+ if (
55
+ node .target == operator .getitem
56
+ or node .target .__name__ == "alloc"
57
+ or node .target == memory .view
58
+ ):
59
+ continue
60
+ # If the op is already present, increment the count
61
+ if get_edge_overload_packet (node .target ).__name__ in freq :
62
+ freq [get_edge_overload_packet (node .target ).__name__ ] += 1
63
+ # else, add a new entry
64
+ else :
65
+ freq [get_edge_overload_packet (node .target ).__name__ ] = 1
66
+ return freq
67
+
68
+
69
+ # Print the ops and how many times they occur multiple graph modules:
70
+ # from export, from to_edge, and from Jarvis. Print the available
71
+ # implementations for each op, and error out if the op is not supported.
72
+ def print_ops_info (
73
+ export_gm : torch .fx .GraphModule ,
74
+ to_edge_gm : torch .fx .GraphModule ,
75
+ jarvis_gm : torch .fx .GraphModule ,
76
+ ):
77
+ export_ops_count = get_ops_count (export_gm )
78
+ to_edge_ops_count = get_ops_count (to_edge_gm )
79
+ jarvis_ops_count = get_ops_count (jarvis_gm )
80
+
81
+ # De-duplicate the "<op>" and "<op>_copy" ops
82
+ keys_to_delete_and_add = []
83
+ for k1 in export_ops_count :
84
+ for k2 in {** to_edge_ops_count , ** jarvis_ops_count }:
85
+ if k2 .startswith (k1 ):
86
+ keys_to_delete_and_add .append ((k1 , k2 ))
87
+ break
88
+
89
+ for k in keys_to_delete_and_add :
90
+ export_ops_count [k [1 ]] = export_ops_count [k [0 ]]
91
+ del export_ops_count [k [0 ]]
92
+
93
+ removed_ops = []
94
+ # Get the counts of the ops that are removed from the final graph
95
+ for k in {** export_ops_count , ** to_edge_ops_count }:
96
+ if k not in jarvis_ops_count :
97
+ removed_ops .append (k )
98
+
99
+ # Create a dict of ops and their counts to pass to tabulate
100
+ ops_count = [
101
+ [
102
+ op ,
103
+ jarvis_ops_count [op ],
104
+ to_edge_ops_count [op ] if op in to_edge_ops_count else 0 ,
105
+ export_ops_count [op ] if op in export_ops_count else 0 ,
106
+ ]
107
+ for op in jarvis_ops_count
108
+ ]
109
+ sorted_ops_count = sorted (ops_count , key = lambda x : x [1 ], reverse = True )
110
+
111
+ # Create a dict of deleted ops and their counts to pass to tabulate
112
+ removed_ops_count = [
113
+ [
114
+ op ,
115
+ 0 ,
116
+ to_edge_ops_count [op ] if op in to_edge_ops_count else 0 ,
117
+ export_ops_count [op ] if op in export_ops_count else 0 ,
118
+ ]
119
+ for op in removed_ops
120
+ ]
121
+
122
+ # Print the final ops and their counts in a tabular format
123
+ print (
124
+ tabulate (
125
+ sorted_ops_count ,
126
+ headers = [
127
+ "Final Operators " , # one character longer than the longest op name
128
+ "Jarvis (Final) Graph" ,
129
+ "To_edge Graph" ,
130
+ "Export Graph" ,
131
+ "Implementation Available" ,
132
+ ],
133
+ tablefmt = "outline" ,
134
+ )
135
+ )
136
+
137
+ # Print the removed ops and their counts in a tabular format (if any)
138
+ if removed_ops != []:
139
+ print (
140
+ tabulate (
141
+ removed_ops_count ,
142
+ headers = [
143
+ "Deleted Operators " , # one character longer than the longest op name
144
+ "Jarvis (Final) Graph" ,
145
+ "To_edge Graph" ,
146
+ "Export Graph" ,
147
+ "Implementation Available" ,
148
+ ],
149
+ tablefmt = "outline" ,
150
+ )
151
+ )
0 commit comments