Skip to content

Commit 5e2d830

Browse files
Olivia-liufacebook-github-bot
authored andcommitted
New util function for getting delegation summary (#2611)
Summary: Because what happened to the graph in to_backend() is pretty opaque to the users, the existing [print_delegated_graph()](https://fburl.com/code/4xa5oewv) is good but might be too much for new users, so implemented this new util function to give a more lightweight summary for the delegated graph. Original design doc: https://docs.google.com/document/d/19ZSDddm23MnGvFUrkV9clwHwLvXzAo2IZtbkr9VChjE/edit?usp=sharing **Note:** The added `import pandas` shouldn't be an issue because we already have pandas in the install requirement is executorch oss: https://fburl.com/code/p55654r2 Differential Revision: D55239751
1 parent 67a7d20 commit 5e2d830

File tree

3 files changed

+229
-0
lines changed

3 files changed

+229
-0
lines changed

exir/backend/test/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ python_unittest(
267267
"test_utils.py",
268268
],
269269
deps = [
270+
"fbsource//third-party/pypi/pandas:pandas",
270271
":op_partitioner_demo",
271272
"//caffe2:torch",
272273
"//executorch/exir:lib",

exir/backend/test/test_utils.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,18 @@
66

77
import unittest
88

9+
import pandas as pd
10+
911
import torch
1012
from executorch import exir
1113
from executorch.exir import CaptureConfig, to_edge
1214
from executorch.exir.backend.backend_api import to_backend
1315
from executorch.exir.backend.partitioner import Partitioner, PartitionResult
1416
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
1517
from executorch.exir.backend.utils import (
18+
DelegationBreakdown,
1619
get_delegates,
20+
get_delegation_info,
1721
get_non_lowered_nodes,
1822
is_identical_graph,
1923
print_delegated_graph,
@@ -22,6 +26,7 @@
2226
)
2327

2428
from executorch.exir.dialects._ops import bind_pattern_to_op, ops as exir_ops
29+
from pandas.testing import assert_frame_equal
2530
from torch.ao.quantization import get_default_qconfig # @manual
2631
from torch.ao.quantization.backend_config.executorch import (
2732
get_executorch_backend_config,
@@ -439,3 +444,57 @@ def forward(self, a, x, b):
439444
graph_str,
440445
"Expect to see the aten.mm in the delegated graph",
441446
)
447+
448+
def test_get_delegation_info(self):
449+
class Model(torch.nn.Module):
450+
def __init__(self):
451+
super().__init__()
452+
453+
def forward(self, a, x, b):
454+
y = torch.mm(a, x)
455+
z = y + b
456+
a = z - a
457+
y = torch.mm(a, x)
458+
z = y + b
459+
return z
460+
461+
m = Model()
462+
inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
463+
edge = to_edge(torch.export.export(m, inputs)).to_backend(
464+
AddMulPartitionerDemo()
465+
)
466+
delegation_info = get_delegation_info(edge.exported_program().graph_module)
467+
468+
self.assertEqual(delegation_info.num_delegated_subgraphs, 2)
469+
self.assertEqual(delegation_info.num_delegated_nodes, 4)
470+
self.assertEqual(delegation_info.num_non_delegated_nodes, 3)
471+
expected_delegation_by_op_dict = {
472+
"aten_add_tensor": DelegationBreakdown(delegated=2, non_delegated=0),
473+
"aten_mm_default": DelegationBreakdown(delegated=2, non_delegated=0),
474+
"aten_sub_tensor": DelegationBreakdown(delegated=0, non_delegated=1),
475+
"getitem": DelegationBreakdown(delegated=0, non_delegated=2),
476+
}
477+
self.assertEqual(
478+
delegation_info.delegation_by_operator, expected_delegation_by_op_dict
479+
)
480+
481+
self.assertIn(
482+
"Total delegated subgraphs",
483+
delegation_info.get_summary(),
484+
)
485+
486+
df = delegation_info.get_operator_delegation_dataframe()
487+
expected_df = pd.DataFrame(
488+
{
489+
"op_type": [
490+
"aten_add_tensor",
491+
"aten_mm_default",
492+
"aten_sub_tensor",
493+
"getitem",
494+
"Total",
495+
],
496+
"occurrences_in_delegated_graphs": [2, 2, 0, 0, 4],
497+
"occurrences_in_non_delegated_graphs": [0, 0, 1, 2, 3],
498+
}
499+
)
500+
assert_frame_equal(expected_df, df)

exir/backend/utils.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,13 @@
66

77
import logging
88
import operator
9+
import re
910
from collections import defaultdict
11+
from dataclasses import dataclass
1012
from functools import lru_cache
1113
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
1214

15+
import pandas as pd
1316
import torch
1417
from executorch.exir.backend.backend_details import ExportedProgram
1518
from executorch.exir.backend.canonical_partitioners.duplicate_constant_node_pass import (
@@ -280,6 +283,172 @@ def get_delegates(graph: torch.fx.Graph) -> List[torch.fx.Node]:
280283
]
281284

282285

286+
@dataclass
287+
class DelegationBreakdown:
288+
"""
289+
DelegationBreakdown contains for the number of delegated nodes and non-delegated nodes
290+
of an operator type.
291+
292+
Args:
293+
delegated: The number of delegated nodes.
294+
non_delegated: The number of non-delegated nodes.
295+
"""
296+
297+
delegated: int
298+
non_delegated: int
299+
300+
301+
@dataclass
302+
class DelegationInfo:
303+
"""
304+
DelegationInfo contains information of a delegated graph module.
305+
306+
Args:
307+
num_delegated_subgraphs: The number of delegated subgraphs.
308+
num_delegated_nodes: The number of delegated nodes.
309+
num_non_delegated_nodes: The number of non-delegated nodes.
310+
delegation_by_operator: A dictionary of operator type to DelegationBreakdown.
311+
"""
312+
313+
num_delegated_subgraphs: int
314+
num_delegated_nodes: int
315+
num_non_delegated_nodes: int
316+
delegation_by_operator: Dict[str, DelegationBreakdown]
317+
318+
def get_summary(self) -> str:
319+
"""
320+
Get a summary of the delegation information in string format.
321+
322+
Args:
323+
None
324+
325+
Returns:
326+
A string containing information of some class attributes for easy print-out.
327+
"""
328+
329+
# Assemble and return the summary string
330+
summary_str = f"Total delegated subgraphs: {self.num_delegated_subgraphs}\n"
331+
summary_str += f"Number of delegated nodes: {self.num_delegated_nodes}\n"
332+
summary_str += (
333+
f"Number of non-delegated nodes: {self.num_non_delegated_nodes}\n"
334+
)
335+
return summary_str
336+
337+
def get_operator_delegation_dataframe(self) -> pd.DataFrame:
338+
"""
339+
Get the delegation information grouped by operator type in a pandas DataFrame.
340+
341+
Args:
342+
None
343+
344+
Returns:
345+
Returns a pandas DataFrame containing the following columns:
346+
- op_type: The operator type, with the last row being "Total".
347+
- occurrences_in_delegated_graphs: The number of occurrences of the op_type in delegated subgraphs.
348+
- occurrences_in_non_delegated_graphs: The number of occurrences of the op_type not in delegated subgraphs.
349+
With the last row being the total number of delegated and non-delegated occurrences of each op_type.
350+
"""
351+
352+
# Convert the dict to a dataframe
353+
df = pd.DataFrame(
354+
list(self.delegation_by_operator.items()),
355+
columns=["op_type", "occurrences"],
356+
)
357+
358+
# Function to extract delegated and non-delegated fields
359+
def _extract_breakdown(row):
360+
return pd.Series(
361+
[row["occurrences"].delegated, row["occurrences"].non_delegated]
362+
)
363+
364+
df[
365+
["occurrences_in_delegated_graphs", "occurrences_in_non_delegated_graphs"]
366+
] = df.apply(lambda row: _extract_breakdown(row), axis=1)
367+
df.drop(columns=["occurrences"], inplace=True)
368+
df = df.sort_values(by="op_type", ignore_index=True)
369+
370+
# Add a Total row at the bottom
371+
total_delegated_nodes = df["occurrences_in_delegated_graphs"].sum()
372+
total_non_delegated_nodes = df["occurrences_in_non_delegated_graphs"].sum()
373+
df.loc[len(df)] = ["Total", total_delegated_nodes, total_non_delegated_nodes]
374+
375+
return df
376+
377+
378+
def get_delegation_info(
379+
graph_module: torch.fx.GraphModule,
380+
) -> DelegationInfo:
381+
"""
382+
Util function to get the delegation information of the given graph module.
383+
384+
Args:
385+
graph_module: The lowered graph module to get the delegation information from.
386+
387+
Returns:
388+
Return a DelegationInfo object containing the delegation information.
389+
"""
390+
391+
def _get_op_type(node_name: str) -> str:
392+
# node_name is in format <op_type> or <op_type>_x in which x is an integer suffix.
393+
return re.sub(r"_[\d]+$", "", node_name)
394+
395+
op_occurrences_dict = {}
396+
397+
def _insert_op_occurrences_dict(node_name: str, delegated: bool) -> None:
398+
op_type = _get_op_type(node_name)
399+
if op_type not in op_occurrences_dict:
400+
if delegated:
401+
op_occurrences_dict[op_type] = DelegationBreakdown(
402+
delegated=1, non_delegated=0
403+
)
404+
else:
405+
op_occurrences_dict[op_type] = DelegationBreakdown(
406+
delegated=0, non_delegated=1
407+
)
408+
else:
409+
if delegated:
410+
op_occurrences_dict[op_type].delegated += 1
411+
else:
412+
op_occurrences_dict[op_type].non_delegated += 1
413+
414+
delegated_subgraph_counter = 0
415+
416+
lowered_module_dict = {
417+
node.name: getattr(graph_module, node.name)
418+
for node in graph_module.graph.nodes
419+
if node.op == "get_attr" and node.name.startswith("lowered_module_")
420+
}
421+
422+
for node in graph_module.graph.nodes:
423+
if (
424+
node.op == "call_function"
425+
and _get_op_type(node.name) != "executorch_call_delegate"
426+
):
427+
_insert_op_occurrences_dict(node_name=node.name, delegated=False)
428+
if node.op == "get_attr" and node.name.startswith("lowered_module_"):
429+
lowered_module = lowered_module_dict[node.name]
430+
delegated_subgraph_counter += 1
431+
for node_in_lowered_module in lowered_module.original_module.graph.nodes:
432+
if node_in_lowered_module.op == "call_function":
433+
_insert_op_occurrences_dict(
434+
node_name=node_in_lowered_module.name, delegated=True
435+
)
436+
437+
# Calculate the total number of delegated and non-delegated nodes
438+
num_delegated_nodes = 0
439+
num_non_delegated_nodes = 0
440+
for _, value in op_occurrences_dict.items():
441+
num_delegated_nodes += value.delegated
442+
num_non_delegated_nodes += value.non_delegated
443+
444+
return DelegationInfo(
445+
num_delegated_nodes=num_delegated_nodes,
446+
num_non_delegated_nodes=num_non_delegated_nodes,
447+
num_delegated_subgraphs=delegated_subgraph_counter,
448+
delegation_by_operator=op_occurrences_dict,
449+
)
450+
451+
283452
def print_delegated_graph(graph_module: torch.fx.GraphModule) -> str:
284453
"""
285454
Print the graph of including lowered_module (both backend id and original graph) together with the graph module. Example output:

0 commit comments

Comments
 (0)