Skip to content

Commit 5f906b2

Browse files
Olivia-liufacebook-github-bot
authored andcommitted
New util function for getting delegation summary (#2611)
Summary: Because what happened to the graph in to_backend() is pretty opaque to the users, the existing [print_delegated_graph()](https://fburl.com/code/4xa5oewv) is good but might be too much for new users, so implemented this new util function to give a more lightweight summary for the delegated graph. Original design doc: https://docs.google.com/document/d/19ZSDddm23MnGvFUrkV9clwHwLvXzAo2IZtbkr9VChjE/edit?usp=sharing **Note:** The added `import pandas` shouldn't be an issue because we already have pandas in the install requirement is executorch oss: https://fburl.com/code/p55654r2 Differential Revision: D55239751
1 parent a3bf63b commit 5f906b2

File tree

3 files changed

+232
-0
lines changed

3 files changed

+232
-0
lines changed

exir/backend/test/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ python_unittest(
269269
"test_utils.py",
270270
],
271271
deps = [
272+
"fbsource//third-party/pypi/pandas:pandas",
272273
":op_partitioner_demo",
273274
"//caffe2:torch",
274275
"//executorch/exir:lib",

exir/backend/test/test_utils.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,18 @@
66

77
import unittest
88

9+
import pandas as pd
10+
911
import torch
1012
from executorch import exir
1113
from executorch.exir import CaptureConfig, to_edge
1214
from executorch.exir.backend.backend_api import to_backend
1315
from executorch.exir.backend.partitioner import Partitioner, PartitionResult
1416
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
1517
from executorch.exir.backend.utils import (
18+
DelegationBreakdown,
1619
get_delegates,
20+
get_delegation_info,
1721
get_non_lowered_nodes,
1822
is_identical_graph,
1923
print_delegated_graph,
@@ -22,6 +26,7 @@
2226
)
2327

2428
from executorch.exir.dialects._ops import bind_pattern_to_op, ops as exir_ops
29+
from pandas.testing import assert_frame_equal
2530
from torch.ao.quantization import get_default_qconfig # @manual
2631
from torch.ao.quantization.backend_config.executorch import (
2732
get_executorch_backend_config,
@@ -439,3 +444,65 @@ def forward(self, a, x, b):
439444
graph_str,
440445
"Expect to see the aten.mm in the delegated graph",
441446
)
447+
448+
def test_get_delegation_info(self):
449+
class Model(torch.nn.Module):
450+
def __init__(self):
451+
super().__init__()
452+
453+
def forward(self, a, x, b):
454+
y = torch.mm(a, x)
455+
z = y + b
456+
a = z - a
457+
y = torch.mm(a, x)
458+
z = y + b
459+
return z
460+
461+
m = Model()
462+
inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
463+
edge = to_edge(torch.export.export(m, inputs)).to_backend(
464+
AddMulPartitionerDemo()
465+
)
466+
delegation_info = get_delegation_info(edge.exported_program().graph_module)
467+
468+
self.assertEqual(delegation_info.num_delegated_subgraphs, 2)
469+
self.assertEqual(delegation_info.num_delegated_nodes, 4)
470+
self.assertEqual(delegation_info.num_non_delegated_nodes, 3)
471+
expected_delegation_by_op_dict = {
472+
"aten_add_tensor": DelegationBreakdown(
473+
op_type="aten_add_tensor", delegated=2, non_delegated=0
474+
),
475+
"aten_mm_default": DelegationBreakdown(
476+
op_type="aten_mm_default", delegated=2, non_delegated=0
477+
),
478+
"aten_sub_tensor": DelegationBreakdown(
479+
op_type="aten_sub_tensor", delegated=0, non_delegated=1
480+
),
481+
"getitem": DelegationBreakdown(
482+
op_type="getitem", delegated=0, non_delegated=2
483+
),
484+
}
485+
self.assertEqual(
486+
delegation_info.delegation_by_operator, expected_delegation_by_op_dict
487+
)
488+
489+
self.assertIn(
490+
"Total delegated subgraphs",
491+
delegation_info.get_summary(),
492+
)
493+
494+
df = delegation_info.get_operator_delegation_dataframe()
495+
expected_df = pd.DataFrame(
496+
{
497+
"op_type": [
498+
"aten_add_tensor",
499+
"aten_mm_default",
500+
"aten_sub_tensor",
501+
"getitem",
502+
"Total",
503+
],
504+
"occurrences_in_delegated_graphs": [2, 2, 0, 0, 4],
505+
"occurrences_in_non_delegated_graphs": [0, 0, 1, 2, 3],
506+
}
507+
)
508+
assert_frame_equal(expected_df, df)

exir/backend/utils.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,13 @@
66

77
import logging
88
import operator
9+
import re
910
from collections import defaultdict
11+
from dataclasses import asdict, dataclass
1012
from functools import lru_cache
1113
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
1214

15+
import pandas as pd
1316
import torch
1417
from executorch.exir.backend.backend_details import ExportedProgram
1518
from executorch.exir.backend.canonical_partitioners.duplicate_constant_node_pass import (
@@ -27,6 +30,10 @@
2730
T_QuantPerTensor = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
2831
T_DQuantPerTensor = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
2932

33+
OCCURRENCES_IN_DELEGATED_GRAPHS = "occurrences_in_delegated_graphs"
34+
OCCURRENCES_IN_NON_DELEGATED_GRAPHS = "occurrences_in_non_delegated_graphs"
35+
36+
3037
log: logging.Logger = logging.getLogger(__name__)
3138

3239

@@ -280,6 +287,163 @@ def get_delegates(graph: torch.fx.Graph) -> List[torch.fx.Node]:
280287
]
281288

282289

290+
@dataclass
291+
class DelegationBreakdown:
292+
"""
293+
DelegationBreakdown contains the number of delegated and non-delegated nodes
294+
of the operator type op_type.
295+
296+
Args:
297+
delegated: The number of delegated nodes.
298+
non_delegated: The number of non-delegated nodes.
299+
"""
300+
301+
op_type: str = ""
302+
delegated: int = 0
303+
non_delegated: int = 0
304+
305+
306+
@dataclass
307+
class DelegationInfo:
308+
"""
309+
DelegationInfo contains information of a delegated graph module.
310+
311+
Args:
312+
num_delegated_subgraphs: The number of delegated subgraphs.
313+
num_delegated_nodes: The number of delegated nodes.
314+
num_non_delegated_nodes: The number of non-delegated nodes.
315+
delegation_by_operator: A dictionary of operator type to DelegationBreakdown.
316+
"""
317+
318+
num_delegated_subgraphs: int
319+
num_delegated_nodes: int
320+
num_non_delegated_nodes: int
321+
delegation_by_operator: Dict[str, DelegationBreakdown]
322+
323+
def get_summary(self) -> str:
324+
"""
325+
Get a summary of the delegation information in string format.
326+
327+
Args:
328+
None
329+
330+
Returns:
331+
A string containing information of some class attributes for easy print-out.
332+
"""
333+
334+
# Assemble and return the summary string
335+
summary_str = f"Total delegated subgraphs: {self.num_delegated_subgraphs}\n"
336+
summary_str += f"Number of delegated nodes: {self.num_delegated_nodes}\n"
337+
summary_str += (
338+
f"Number of non-delegated nodes: {self.num_non_delegated_nodes}\n"
339+
)
340+
return summary_str
341+
342+
def get_operator_delegation_dataframe(self) -> pd.DataFrame:
343+
"""
344+
Get the delegation information grouped by operator type in a pandas DataFrame.
345+
346+
Args:
347+
None
348+
349+
Returns:
350+
Returns a pandas DataFrame containing the following columns:
351+
- op_type: The operator type, with the last row being "Total".
352+
- occurrences_in_delegated_graphs: The number of occurrences of the op_type in delegated subgraphs.
353+
- occurrences_in_non_delegated_graphs: The number of occurrences of the op_type not in delegated subgraphs.
354+
With the last row being the total number of delegated and non-delegated occurrences of each op_type.
355+
"""
356+
357+
# Convert the dict to a dataframe
358+
list_of_dicts = [
359+
asdict(breakdown) for breakdown in self.delegation_by_operator.values()
360+
]
361+
df = pd.DataFrame(list_of_dicts)
362+
# Rename columns for better understandability
363+
df = df.rename(
364+
columns={
365+
"delegated": OCCURRENCES_IN_DELEGATED_GRAPHS,
366+
"non_delegated": OCCURRENCES_IN_NON_DELEGATED_GRAPHS,
367+
}
368+
)
369+
df = df.sort_values(by="op_type", ignore_index=True)
370+
371+
# Add a Total row at the bottom
372+
total_delegated_nodes = df[OCCURRENCES_IN_DELEGATED_GRAPHS].sum()
373+
total_non_delegated_nodes = df[OCCURRENCES_IN_NON_DELEGATED_GRAPHS].sum()
374+
df.loc[len(df)] = ["Total", total_delegated_nodes, total_non_delegated_nodes]
375+
376+
return df
377+
378+
379+
def get_delegation_info(
380+
graph_module: torch.fx.GraphModule,
381+
) -> DelegationInfo:
382+
"""
383+
Util function to get the delegation information of the given graph module.
384+
385+
Args:
386+
graph_module: The lowered graph module to get the delegation information from.
387+
388+
Returns:
389+
Return a DelegationInfo object containing the delegation information.
390+
"""
391+
392+
def _get_op_type(node_name: str) -> str:
393+
# node_name is in format <op_type> or <op_type>_x in which x is an integer suffix.
394+
return re.sub(r"_[\d]+$", "", node_name)
395+
396+
op_occurrences_dict = defaultdict(lambda: DelegationBreakdown())
397+
398+
def _insert_op_occurrences_dict(node_name: str, delegated: bool) -> None:
399+
op_type = _get_op_type(node_name)
400+
op_occurrences_dict[op_type].op_type = op_type
401+
if delegated:
402+
op_occurrences_dict[op_type].delegated += 1
403+
else:
404+
op_occurrences_dict[op_type].non_delegated += 1
405+
406+
delegated_subgraph_counter = 0
407+
408+
lowered_module_dict = {
409+
node.name: getattr(graph_module, node.name)
410+
for node in graph_module.graph.nodes
411+
if node.op == "get_attr" and node.name.startswith("lowered_module_")
412+
}
413+
414+
for node in graph_module.graph.nodes:
415+
if (
416+
node.op == "call_function"
417+
and _get_op_type(node.name) != "executorch_call_delegate"
418+
):
419+
# Non-delegated node
420+
_insert_op_occurrences_dict(node_name=node.name, delegated=False)
421+
# Check if the node is a lowered module
422+
if node.op == "get_attr" and node.name.startswith("lowered_module_"):
423+
lowered_module = lowered_module_dict[node.name]
424+
delegated_subgraph_counter += 1
425+
for node_in_lowered_module in lowered_module.original_module.graph.nodes:
426+
if node_in_lowered_module.op == "call_function":
427+
# Delegated node
428+
_insert_op_occurrences_dict(
429+
node_name=node_in_lowered_module.name, delegated=True
430+
)
431+
432+
# Calculate the total number of delegated and non-delegated nodes
433+
num_delegated_nodes = 0
434+
num_non_delegated_nodes = 0
435+
for value in op_occurrences_dict.values():
436+
num_delegated_nodes += value.delegated
437+
num_non_delegated_nodes += value.non_delegated
438+
439+
return DelegationInfo(
440+
num_delegated_nodes=num_delegated_nodes,
441+
num_non_delegated_nodes=num_non_delegated_nodes,
442+
num_delegated_subgraphs=delegated_subgraph_counter,
443+
delegation_by_operator=op_occurrences_dict,
444+
)
445+
446+
283447
def print_delegated_graph(graph_module: torch.fx.GraphModule) -> str:
284448
"""
285449
Print the graph of including lowered_module (both backend id and original graph) together with the graph module. Example output:

0 commit comments

Comments
 (0)