Skip to content

Commit a19a32b

Browse files
Olivia-liufacebook-github-bot
authored andcommitted
New util function for getting delegation summary (#2611)
Summary: bypass-github-export-checks bypass-github-pytorch-ci-checks bypass-github-executorch-ci-checks Pull Request resolved: #2611 Because what happened to the graph in to_backend() is pretty opaque to the users, the existing [print_delegated_graph()](https://fburl.com/code/4xa5oewv) is good but might be too much for new users, so implemented this new util function to give a more lightweight summary for the delegated graph. Original design doc: https://docs.google.com/document/d/19ZSDddm23MnGvFUrkV9clwHwLvXzAo2IZtbkr9VChjE/edit?usp=sharing **Note:** The added `import pandas` shouldn't be an issue because we already have pandas in the install requirement is executorch oss: https://fburl.com/code/p55654r2 Reviewed By: dbort Differential Revision: D55239751 fbshipit-source-id: 81000953b51d87f6d509a6589685fe677a8eec5b
1 parent 66c5fc8 commit a19a32b

File tree

4 files changed

+235
-0
lines changed

4 files changed

+235
-0
lines changed

exir/backend/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ runtime.python_library(
106106
"@EXECUTORCH_CLIENTS",
107107
],
108108
deps = [
109+
"fbsource//third-party/pypi/pandas:pandas",
109110
"//caffe2:torch",
110111
"//executorch/exir:lowered_backend_module",
111112
],

exir/backend/test/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ python_unittest(
269269
"test_utils.py",
270270
],
271271
deps = [
272+
"fbsource//third-party/pypi/pandas:pandas",
272273
":op_partitioner_demo",
273274
"//caffe2:torch",
274275
"//executorch/exir:lib",

exir/backend/test/test_utils.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,18 @@
66

77
import unittest
88

9+
import pandas as pd
10+
911
import torch
1012
from executorch import exir
1113
from executorch.exir import CaptureConfig, to_edge
1214
from executorch.exir.backend.backend_api import to_backend
1315
from executorch.exir.backend.partitioner import Partitioner, PartitionResult
1416
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
1517
from executorch.exir.backend.utils import (
18+
DelegationBreakdown,
1619
get_delegates,
20+
get_delegation_info,
1721
get_non_lowered_nodes,
1822
is_identical_graph,
1923
print_delegated_graph,
@@ -22,6 +26,7 @@
2226
)
2327

2428
from executorch.exir.dialects._ops import bind_pattern_to_op, ops as exir_ops
29+
from pandas.testing import assert_frame_equal
2530
from torch.ao.quantization import get_default_qconfig # @manual
2631
from torch.ao.quantization.backend_config.executorch import (
2732
get_executorch_backend_config,
@@ -439,3 +444,65 @@ def forward(self, a, x, b):
439444
graph_str,
440445
"Expect to see the aten.mm in the delegated graph",
441446
)
447+
448+
def test_get_delegation_info(self):
449+
class Model(torch.nn.Module):
450+
def __init__(self):
451+
super().__init__()
452+
453+
def forward(self, a, x, b):
454+
y = torch.mm(a, x)
455+
z = y + b
456+
a = z - a
457+
y = torch.mm(a, x)
458+
z = y + b
459+
return z
460+
461+
m = Model()
462+
inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
463+
edge = to_edge(torch.export.export(m, inputs)).to_backend(
464+
AddMulPartitionerDemo()
465+
)
466+
delegation_info = get_delegation_info(edge.exported_program().graph_module)
467+
468+
self.assertEqual(delegation_info.num_delegated_subgraphs, 2)
469+
self.assertEqual(delegation_info.num_delegated_nodes, 4)
470+
self.assertEqual(delegation_info.num_non_delegated_nodes, 3)
471+
expected_delegation_by_op_dict = {
472+
"aten_add_tensor": DelegationBreakdown(
473+
op_type="aten_add_tensor", delegated=2, non_delegated=0
474+
),
475+
"aten_mm_default": DelegationBreakdown(
476+
op_type="aten_mm_default", delegated=2, non_delegated=0
477+
),
478+
"aten_sub_tensor": DelegationBreakdown(
479+
op_type="aten_sub_tensor", delegated=0, non_delegated=1
480+
),
481+
"getitem": DelegationBreakdown(
482+
op_type="getitem", delegated=0, non_delegated=2
483+
),
484+
}
485+
self.assertEqual(
486+
delegation_info.delegation_by_operator, expected_delegation_by_op_dict
487+
)
488+
489+
self.assertIn(
490+
"Total delegated subgraphs",
491+
delegation_info.get_summary(),
492+
)
493+
494+
df = delegation_info.get_operator_delegation_dataframe()
495+
expected_df = pd.DataFrame(
496+
{
497+
"op_type": [
498+
"aten_add_tensor",
499+
"aten_mm_default",
500+
"aten_sub_tensor",
501+
"getitem",
502+
"Total",
503+
],
504+
"occurrences_in_delegated_graphs": [2, 2, 0, 0, 4],
505+
"occurrences_in_non_delegated_graphs": [0, 0, 1, 2, 3],
506+
}
507+
)
508+
assert_frame_equal(expected_df, df)

exir/backend/utils.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,13 @@
66

77
import logging
88
import operator
9+
import re
910
from collections import defaultdict
11+
from dataclasses import asdict, dataclass
1012
from functools import lru_cache
1113
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
1214

15+
import pandas as pd
1316
import torch
1417
from executorch.exir.backend.backend_details import ExportedProgram
1518
from executorch.exir.backend.canonical_partitioners.duplicate_constant_node_pass import (
@@ -27,6 +30,12 @@
2730
T_QuantPerTensor = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
2831
T_DQuantPerTensor = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
2932

33+
# Column names of the DataFrame returned by DelegationInfo.get_operator_delegation_dataframe()
34+
# which describes the summarized delegation information grouped by each operator type
35+
_OCCURRENCES_IN_DELEGATED_GRAPHS = "occurrences_in_delegated_graphs"
36+
_OCCURRENCES_IN_NON_DELEGATED_GRAPHS = "occurrences_in_non_delegated_graphs"
37+
38+
3039
log: logging.Logger = logging.getLogger(__name__)
3140

3241

@@ -280,6 +289,163 @@ def get_delegates(graph: torch.fx.Graph) -> List[torch.fx.Node]:
280289
]
281290

282291

292+
@dataclass
293+
class DelegationBreakdown:
294+
"""
295+
DelegationBreakdown contains the number of delegated and non-delegated nodes
296+
of the operator type op_type.
297+
298+
Args:
299+
delegated: The number of delegated nodes.
300+
non_delegated: The number of non-delegated nodes.
301+
"""
302+
303+
op_type: str = ""
304+
delegated: int = 0
305+
non_delegated: int = 0
306+
307+
308+
@dataclass
309+
class DelegationInfo:
310+
"""
311+
DelegationInfo contains information of a delegated graph module.
312+
313+
Args:
314+
num_delegated_subgraphs: The number of delegated subgraphs.
315+
num_delegated_nodes: The number of delegated nodes.
316+
num_non_delegated_nodes: The number of non-delegated nodes.
317+
delegation_by_operator: A dictionary of operator type to DelegationBreakdown.
318+
"""
319+
320+
num_delegated_subgraphs: int
321+
num_delegated_nodes: int
322+
num_non_delegated_nodes: int
323+
delegation_by_operator: Dict[str, DelegationBreakdown]
324+
325+
def get_summary(self) -> str:
326+
"""
327+
Get a summary of the delegation information in string format.
328+
329+
Args:
330+
None
331+
332+
Returns:
333+
A string containing information of some class attributes for easy print-out.
334+
"""
335+
336+
# Assemble and return the summary string
337+
summary_str = f"Total delegated subgraphs: {self.num_delegated_subgraphs}\n"
338+
summary_str += f"Number of delegated nodes: {self.num_delegated_nodes}\n"
339+
summary_str += (
340+
f"Number of non-delegated nodes: {self.num_non_delegated_nodes}\n"
341+
)
342+
return summary_str
343+
344+
def get_operator_delegation_dataframe(self) -> pd.DataFrame:
345+
"""
346+
Get the delegation information grouped by operator type in a pandas DataFrame.
347+
348+
Args:
349+
None
350+
351+
Returns:
352+
Returns a pandas DataFrame containing the following columns:
353+
- op_type: The operator type, with the last row being "Total".
354+
- occurrences_in_delegated_graphs: The number of occurrences of the op_type in delegated subgraphs.
355+
- occurrences_in_non_delegated_graphs: The number of occurrences of the op_type not in delegated subgraphs.
356+
With the last row being the total number of delegated and non-delegated occurrences of each op_type.
357+
"""
358+
359+
# Convert the dict to a dataframe
360+
list_of_dicts = [
361+
asdict(breakdown) for breakdown in self.delegation_by_operator.values()
362+
]
363+
df = pd.DataFrame(list_of_dicts)
364+
# Rename columns for better understandability
365+
df = df.rename(
366+
columns={
367+
"delegated": _OCCURRENCES_IN_DELEGATED_GRAPHS,
368+
"non_delegated": _OCCURRENCES_IN_NON_DELEGATED_GRAPHS,
369+
}
370+
)
371+
df = df.sort_values(by="op_type", ignore_index=True)
372+
373+
# Add a Total row at the bottom
374+
total_delegated_nodes = df[_OCCURRENCES_IN_DELEGATED_GRAPHS].sum()
375+
total_non_delegated_nodes = df[_OCCURRENCES_IN_NON_DELEGATED_GRAPHS].sum()
376+
df.loc[len(df)] = ["Total", total_delegated_nodes, total_non_delegated_nodes]
377+
378+
return df
379+
380+
381+
def get_delegation_info(
382+
graph_module: torch.fx.GraphModule,
383+
) -> DelegationInfo:
384+
"""
385+
Util function to get the delegation information of the given graph module.
386+
387+
Args:
388+
graph_module: The lowered graph module to get the delegation information from.
389+
390+
Returns:
391+
Return a DelegationInfo object containing the delegation information.
392+
"""
393+
394+
def _get_op_type(node_name: str) -> str:
395+
# node_name is in format <op_type> or <op_type>_x in which x is an integer suffix.
396+
return re.sub(r"_[\d]+$", "", node_name)
397+
398+
op_occurrences_dict = defaultdict(lambda: DelegationBreakdown())
399+
400+
def _insert_op_occurrences_dict(node_name: str, delegated: bool) -> None:
401+
op_type = _get_op_type(node_name)
402+
op_occurrences_dict[op_type].op_type = op_type
403+
if delegated:
404+
op_occurrences_dict[op_type].delegated += 1
405+
else:
406+
op_occurrences_dict[op_type].non_delegated += 1
407+
408+
delegated_subgraph_counter = 0
409+
410+
lowered_module_dict = {
411+
node.name: getattr(graph_module, node.name)
412+
for node in graph_module.graph.nodes
413+
if node.op == "get_attr" and node.name.startswith("lowered_module_")
414+
}
415+
416+
for node in graph_module.graph.nodes:
417+
if (
418+
node.op == "call_function"
419+
and _get_op_type(node.name) != "executorch_call_delegate"
420+
):
421+
# Non-delegated node
422+
_insert_op_occurrences_dict(node_name=node.name, delegated=False)
423+
# Check if the node is a lowered module
424+
if node.op == "get_attr" and node.name.startswith("lowered_module_"):
425+
lowered_module = lowered_module_dict[node.name]
426+
delegated_subgraph_counter += 1
427+
for node_in_lowered_module in lowered_module.original_module.graph.nodes:
428+
if node_in_lowered_module.op == "call_function":
429+
# Delegated node
430+
_insert_op_occurrences_dict(
431+
node_name=node_in_lowered_module.name, delegated=True
432+
)
433+
434+
# Calculate the total number of delegated and non-delegated nodes
435+
num_delegated_nodes = 0
436+
num_non_delegated_nodes = 0
437+
for value in op_occurrences_dict.values():
438+
num_delegated_nodes += value.delegated
439+
num_non_delegated_nodes += value.non_delegated
440+
441+
return DelegationInfo(
442+
num_delegated_nodes=num_delegated_nodes,
443+
num_non_delegated_nodes=num_non_delegated_nodes,
444+
num_delegated_subgraphs=delegated_subgraph_counter,
445+
delegation_by_operator=op_occurrences_dict,
446+
)
447+
448+
283449
def print_delegated_graph(graph_module: torch.fx.GraphModule) -> str:
284450
"""
285451
Print the graph of including lowered_module (both backend id and original graph) together with the graph module. Example output:

0 commit comments

Comments
 (0)