4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
+ import logging
8
+ import operator
9
+ from typing import Dict
10
+
7
11
import torch
12
+ from executorch .exir import memory
13
+ from executorch .exir .dialects ._ops import ops as exir_ops
14
+ from executorch .exir .dialects .edge ._ops import EdgeOpOverload , EdgeOpOverloadPacket
15
+ from tabulate import tabulate
8
16
9
17
10
18
# Get the output size of a 1D convolution given the input size and parameters
@@ -23,3 +31,120 @@ def get_conv1d_output_size(
23
31
lout = (L + 2 * padding - dilation * (kernel_size - 1 ) - 1 ) // stride + 1
24
32
25
33
return torch .Size ((in_size [0 ], out_channels , lout ))
34
+
35
+
36
+ # Return the overload packet for the edge op
37
+ def get_edge_overload_packet (edge_op : EdgeOpOverload ) -> EdgeOpOverloadPacket :
38
+ edge_op_namespace , edge_op_name = (
39
+ edge_op .namespace ,
40
+ edge_op ._schema .name .split ("::" )[1 ],
41
+ )
42
+ edge_op_overload_packet = getattr (
43
+ getattr (exir_ops .edge , edge_op_namespace ), edge_op_name
44
+ )
45
+ return edge_op_overload_packet
46
+
47
+
48
+ # Get the frequency list of ops in a graph module
49
+ def get_ops_count (graph_module : torch .fx .GraphModule ) -> Dict [str , int ]:
50
+ freq = {}
51
+ # Loop over nodes to count the number of times each op occurs
52
+ for node in graph_module .graph .nodes :
53
+ if node .op == "call_function" :
54
+ # Ignore getitem, alloc and view cases, we only want actual operations
55
+ if (
56
+ node .target == operator .getitem
57
+ or node .target .__name__ == "alloc"
58
+ or node .target == memory .view
59
+ ):
60
+ continue
61
+ # If the op is already present, increment the count
62
+ if get_edge_overload_packet (node .target ).__name__ in freq :
63
+ freq [get_edge_overload_packet (node .target ).__name__ ] += 1
64
+ # else, add a new entry
65
+ else :
66
+ freq [get_edge_overload_packet (node .target ).__name__ ] = 1
67
+ return freq
68
+
69
+
70
+ # Print the ops and how many times they occur multiple graph modules:
71
+ # from export, from to_edge, and from Jarvis. Print the available
72
+ # implementations for each op, and error out if the op is not supported.
73
+ def print_ops_info (
74
+ export_gm : torch .fx .GraphModule ,
75
+ to_edge_gm : torch .fx .GraphModule ,
76
+ jarvis_gm : torch .fx .GraphModule ,
77
+ ):
78
+ export_ops_count = get_ops_count (export_gm )
79
+ to_edge_ops_count = get_ops_count (to_edge_gm )
80
+ jarvis_ops_count = get_ops_count (jarvis_gm )
81
+
82
+ # De-duplicate the "<op>" and "<op>_copy" ops
83
+ keys_to_delete_and_add = []
84
+ for k1 in export_ops_count :
85
+ for k2 in {** to_edge_ops_count , ** jarvis_ops_count }:
86
+ if k2 .startswith (k1 ):
87
+ keys_to_delete_and_add .append ((k1 , k2 ))
88
+ break
89
+
90
+ for k in keys_to_delete_and_add :
91
+ export_ops_count [k [1 ]] = export_ops_count [k [0 ]]
92
+ del export_ops_count [k [0 ]]
93
+
94
+ removed_ops = []
95
+ # Get the counts of the ops that are removed from the final graph
96
+ for k in {** export_ops_count , ** to_edge_ops_count }:
97
+ if k not in jarvis_ops_count :
98
+ removed_ops .append (k )
99
+
100
+ # Create a dict of ops and their counts to pass to tabulate
101
+ ops_count = [
102
+ [
103
+ op ,
104
+ jarvis_ops_count [op ],
105
+ to_edge_ops_count [op ] if op in to_edge_ops_count else 0 ,
106
+ export_ops_count [op ] if op in export_ops_count else 0 ,
107
+ ]
108
+ for op in jarvis_ops_count
109
+ ]
110
+ sorted_ops_count = sorted (ops_count , key = lambda x : x [1 ], reverse = True )
111
+
112
+ # Create a dict of deleted ops and their counts to pass to tabulate
113
+ removed_ops_count = [
114
+ [
115
+ op ,
116
+ 0 ,
117
+ to_edge_ops_count [op ] if op in to_edge_ops_count else 0 ,
118
+ export_ops_count [op ] if op in export_ops_count else 0 ,
119
+ ]
120
+ for op in removed_ops
121
+ ]
122
+
123
+ # Print the final ops and their counts in a tabular format
124
+ logging .info (
125
+ tabulate (
126
+ sorted_ops_count ,
127
+ headers = [
128
+ "Final Operators " , # one character longer than the longest op name
129
+ "Jarvis (Final) Graph" ,
130
+ "To_edge Graph" ,
131
+ "Export Graph" ,
132
+ ],
133
+ tablefmt = "outline" ,
134
+ )
135
+ )
136
+
137
+ # Print the removed ops and their counts in a tabular format (if any)
138
+ if removed_ops != []:
139
+ logging .info (
140
+ tabulate (
141
+ removed_ops_count ,
142
+ headers = [
143
+ "Deleted Operators " , # one character longer than the longest op name
144
+ "Jarvis (Final) Graph" ,
145
+ "To_edge Graph" ,
146
+ "Export Graph" ,
147
+ ],
148
+ tablefmt = "outline" ,
149
+ )
150
+ )
0 commit comments