
Commit 93a7f24

Olivia-liu authored and facebook-github-bot committed
New util function to print delegation summary
Summary: What happens to the graph in to_backend() is fairly opaque to users. The existing [print_delegated_graph()](https://fburl.com/code/4xa5oewv) is useful, but it can be too much detail for new users, so this adds a new util function that gives a more lightweight summary of the delegated graph. Original design doc: https://docs.google.com/document/d/19ZSDddm23MnGvFUrkV9clwHwLvXzAo2IZtbkr9VChjE/edit?usp=sharing

**Note:** The added `import pandas` shouldn't be an issue because pandas is already in the install requirements of executorch OSS: https://fburl.com/code/p55654r2

Differential Revision: D55239751
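For orientation, here is a rough usage sketch of the utility added in this diff. It mirrors the new test further below; the `to_edge` import path and `AddMulPartitionerDemo` are assumed to be available as in that test, and the toy model is illustrative only, not part of the commit.

```python
# Sketch only: exercises get_delegation_info() on a toy model lowered with the
# demo partitioner, following the pattern of the test added by this commit.
import torch

from executorch.exir import to_edge
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
from executorch.exir.backend.utils import get_delegation_info


class ToyModel(torch.nn.Module):
    def forward(self, a, x, b):
        return torch.mm(a, x) + b


inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
edge = to_edge(torch.export.export(ToyModel(), inputs)).to_backend(
    AddMulPartitionerDemo()
)

info = get_delegation_info(edge.exported_program().graph_module)
print(info.get_summary())  # total delegated subgraphs and node counts
df = info.get_operator_delegation_dataframe()
print(df.to_string(index=False))  # per-op-type delegated vs. non-delegated counts
```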
1 parent a9dc341 commit 93a7f24

File tree

exir/backend/test/test_utils.py
exir/backend/utils.py

2 files changed: 172 additions, 0 deletions

exir/backend/test/test_utils.py

Lines changed: 45 additions & 0 deletions
```diff
@@ -14,6 +14,7 @@
 from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
 from executorch.exir.backend.utils import (
     get_delegates,
+    get_delegation_info,
     get_non_lowered_nodes,
     is_identical_graph,
     print_delegated_graph,
@@ -439,3 +440,47 @@ def forward(self, a, x, b):
             graph_str,
             "Expect to see the aten.mm in the delegated graph",
         )
+
+    def test_get_delegation_info(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, a, x, b):
+                y = torch.mm(a, x)
+                z = y + b
+                a = z - a
+                y = torch.mm(a, x)
+                z = y + b
+                return z
+
+        m = Model()
+        inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
+
+        edge = to_edge(torch.export.export(m, inputs)).to_backend(
+            AddMulPartitionerDemo()
+        )
+
+        delegation_info = get_delegation_info(edge.exported_program().graph_module)
+        self.assertEqual(delegation_info.num_delegated_nodes, 4)
+        self.assertIn(
+            "Total delegated subgraphs",
+            delegation_info.get_summary(),
+        )
+        df = delegation_info.get_operator_delegation_dataframe()
+        self.assertFalse(df.empty)
+        integer_columns = [
+            "occurrences in delegated graphs",
+            "occurrences in non-delegated graphs",
+        ]
+        has_positive = False
+        for column in integer_columns:
+            # Check if there is at least one positive number in the column
+            if df[column].gt(0).any():
+                has_positive = True
+                break
+
+        self.assertTrue(
+            has_positive,
+            f"The column '{column}' does not contain any positive numbers",
+        )
```
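A side note on the `has_positive` check in the test above: `df[column].gt(0).any()` is standard pandas and returns True when at least one entry in the column is positive. A standalone toy illustration (the op names and counts are made up, not part of the commit):

```python
import pandas as pd

# Toy dataframe shaped like the one returned by get_operator_delegation_dataframe()
df = pd.DataFrame(
    {
        "op type": ["aten_add_tensor", "aten_mm_default", "aten_sub_tensor"],
        "occurrences in delegated graphs": [2, 2, 0],
        "occurrences in non-delegated graphs": [0, 0, 1],
    }
)

for column in [
    "occurrences in delegated graphs",
    "occurrences in non-delegated graphs",
]:
    # .gt(0) builds a boolean mask; .any() reduces it to a single bool
    print(column, df[column].gt(0).any())  # prints True for both columns here
```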

exir/backend/utils.py

Lines changed: 127 additions & 0 deletions
```diff
@@ -6,10 +6,13 @@
 
 import logging
 import operator
+import re
 from collections import defaultdict
+from dataclasses import dataclass
 from functools import lru_cache
 from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
 
+import pandas as pd
 import torch
 from executorch.exir.backend.backend_details import ExportedProgram
 from executorch.exir.common import setting_python_recursive_limit
@@ -211,6 +214,130 @@ def get_delegates(graph: torch.fx.Graph) -> List[torch.fx.Node]:
     ]
 
 
+@dataclass
+class DelegationBreakdown:
+    delegated: int
+    non_delegated: int
+
+
+@dataclass
+class DelegationInfo:
+    num_delegated_subgraphs: int
+    num_delegated_nodes: int
+    num_non_delegated_nodes: int
+    delegation_by_operator: Dict[str, DelegationBreakdown]
+
+    def get_summary(self) -> str:
+        # Assemble and return the summary string
+        summary_str = f"Total delegated subgraphs: {self.num_delegated_subgraphs}\n"
+        summary_str += f"Number of delegated nodes: {self.num_delegated_nodes}\n"
+        summary_str += (
+            f"Number of non-delegated nodes: {self.num_non_delegated_nodes}\n"
+        )
+        return summary_str
+
+    def get_operator_delegation_dataframe(self) -> pd.DataFrame:
+        # Convert the dict to a dataframe
+        df = pd.DataFrame(
+            list(self.delegation_by_operator.items()),
+            columns=["op type", "occurrences"],
+        )
+
+        # Function to extract delegated and non-delegated fields
+        def _extract_breakdown(row):
+            return pd.Series(
+                [row["occurrences"].delegated, row["occurrences"].non_delegated]
+            )
+
+        df[
+            ["occurrences in delegated graphs", "occurrences in non-delegated graphs"]
+        ] = df.apply(lambda row: _extract_breakdown(row), axis=1)
+        df.drop(columns=["occurrences"], inplace=True)
+
+        # Add a Total row at the bottom
+        total_delegated_nodes = df["occurrences in delegated graphs"].sum()
+        total_non_delegated_nodes = df["occurrences in non-delegated graphs"].sum()
+        df = df.sort_values(by="op type", ignore_index=True)
+        df.loc["Total"] = ["Total", total_delegated_nodes, total_non_delegated_nodes]
+
+        return df
+
+
+def _get_op_type(node_name: str) -> str:
+    # node_name is in the format <op_type> or <op_type>_x, where x is an integer suffix.
+    pattern = r"_(\d+)$"
+    match = re.search(pattern, node_name)
+    if match:
+        return node_name[: match.start()]
+    else:
+        return node_name
+
+
+def get_delegation_info(
+    graph_module: torch.fx.GraphModule,
+) -> DelegationInfo:
+    """
+    Returns a DelegationInfo object for the given graph module, containing a
+    high-level summary string and a per-op-type delegation DataFrame.
+    """
+
+    op_occurrences_dict = {}
+
+    def _insert_op_occurrences_dict(node_name: str, delegated: bool) -> None:
+        op_type = _get_op_type(node_name)
+        if op_type not in op_occurrences_dict:
+            if delegated:
+                op_occurrences_dict[op_type] = DelegationBreakdown(
+                    delegated=1, non_delegated=0
+                )
+            else:
+                op_occurrences_dict[op_type] = DelegationBreakdown(
+                    delegated=0, non_delegated=1
+                )
+        else:
+            if delegated:
+                op_occurrences_dict[op_type].delegated += 1
+            else:
+                op_occurrences_dict[op_type].non_delegated += 1
+
+    delegated_subgraph_counter = 0
+
+    lowered_module_dict = {
+        node.name: getattr(graph_module, node.name)
+        for node in graph_module.graph.nodes
+        if node.op == "get_attr" and node.name.startswith("lowered_module_")
+    }
+
+    for node in graph_module.graph.nodes:
+        if (
+            node.op == "call_function"
+            and _get_op_type(node.name) != "executorch_call_delegate"
+        ):
+            _insert_op_occurrences_dict(node_name=node.name, delegated=False)
+        if node.op == "get_attr" and node.name.startswith("lowered_module_"):
+            lowered_module = lowered_module_dict[node.name]
+            delegated_subgraph_counter += 1
+            for node_in_lowered_module in lowered_module.original_module.graph.nodes:
+                if node_in_lowered_module.op == "call_function":
+                    _insert_op_occurrences_dict(
+                        node_name=node_in_lowered_module.name, delegated=True
+                    )
+
+    # Calculate the total number of delegated and non-delegated nodes
+    num_delegated_nodes = 0
+    num_non_delegated_nodes = 0
+    for _, value in op_occurrences_dict.items():
+        num_delegated_nodes += value.delegated
+        num_non_delegated_nodes += value.non_delegated
+
+    return DelegationInfo(
+        num_delegated_nodes=num_delegated_nodes,
+        num_non_delegated_nodes=num_non_delegated_nodes,
+        num_delegated_subgraphs=delegated_subgraph_counter,
+        delegation_by_operator=op_occurrences_dict,
+    )
+
+
 def print_delegated_graph(graph_module: torch.fx.GraphModule) -> str:
     """
     Print the graph of including lowered_module (both backend id and original graph) together with the graph module. Example output:
```

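A quick standalone illustration of the node-name normalization performed by `_get_op_type` above (the helper is copied here for self-containment; the sample node names are hypothetical):

```python
import re


def _get_op_type(node_name: str) -> str:
    # Same logic as the helper in the diff above: strip a trailing "_<integer>"
    # suffix so repeated nodes of the same op type collapse into one dataframe row.
    match = re.search(r"_(\d+)$", node_name)
    return node_name[: match.start()] if match else node_name


print(_get_op_type("aten_mm_default"))     # -> "aten_mm_default" (no numeric suffix)
print(_get_op_type("aten_mm_default_1"))   # -> "aten_mm_default" (suffix "_1" stripped)
print(_get_op_type("aten_add_tensor_12"))  # -> "aten_add_tensor"
```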