New util function to print delegation summary #2726

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
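
A minimal usage sketch of the new utilities, pieced together from the test added below (the model shape, export flow, and AddMulPartitionerDemo all come from that test; running it assumes an ExecuTorch environment with pandas installed):

import torch
from executorch.exir import to_edge
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
from executorch.exir.backend.utils import get_delegation_info


class Model(torch.nn.Module):
    def forward(self, a, x, b):
        # mm and add are lowered by the demo partitioner; sub stays in the toplevel graph
        return torch.mm(a, x) + b - a


m = Model()
inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
edge = to_edge(torch.export.export(m, inputs)).to_backend(AddMulPartitionerDemo())

delegation_info = get_delegation_info(edge.exported_program().graph_module)
print(delegation_info.get_summary())  # subgraph and node counts
print(delegation_info.get_operator_delegation_dataframe())  # per-operator breakdown
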
1 change: 1 addition & 0 deletions exir/backend/TARGETS
@@ -106,6 +106,7 @@ runtime.python_library(
"@EXECUTORCH_CLIENTS",
],
deps = [
"fbsource//third-party/pypi/pandas:pandas",
"//caffe2:torch",
"//executorch/exir:lowered_backend_module",
],
1 change: 1 addition & 0 deletions exir/backend/test/TARGETS
@@ -269,6 +269,7 @@ python_unittest(
"test_utils.py",
],
deps = [
"fbsource//third-party/pypi/pandas:pandas",
":op_partitioner_demo",
"//caffe2:torch",
"//executorch/exir:lib",
67 changes: 67 additions & 0 deletions exir/backend/test/test_utils.py
@@ -6,14 +6,18 @@

import unittest

import pandas as pd

import torch
from executorch import exir
from executorch.exir import CaptureConfig, to_edge
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.partitioner import Partitioner, PartitionResult
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
from executorch.exir.backend.utils import (
DelegationBreakdown,
get_delegates,
get_delegation_info,
get_non_lowered_nodes,
is_identical_graph,
print_delegated_graph,
@@ -22,6 +26,7 @@
)

from executorch.exir.dialects._ops import bind_pattern_to_op, ops as exir_ops
from pandas.testing import assert_frame_equal
from torch.ao.quantization import get_default_qconfig # @manual
from torch.ao.quantization.backend_config.executorch import (
get_executorch_backend_config,
@@ -439,3 +444,65 @@ def forward(self, a, x, b):
graph_str,
"Expect to see the aten.mm in the delegated graph",
)

def test_get_delegation_info(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, a, x, b):
y = torch.mm(a, x)
z = y + b
a = z - a
y = torch.mm(a, x)
z = y + b
return z

m = Model()
inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
edge = to_edge(torch.export.export(m, inputs)).to_backend(
AddMulPartitionerDemo()
)
delegation_info = get_delegation_info(edge.exported_program().graph_module)

self.assertEqual(delegation_info.num_delegated_subgraphs, 2)
self.assertEqual(delegation_info.num_delegated_nodes, 4)
self.assertEqual(delegation_info.num_non_delegated_nodes, 3)
expected_delegation_by_op_dict = {
"aten_add_tensor": DelegationBreakdown(
op_type="aten_add_tensor", delegated=2, non_delegated=0
),
"aten_mm_default": DelegationBreakdown(
op_type="aten_mm_default", delegated=2, non_delegated=0
),
"aten_sub_tensor": DelegationBreakdown(
op_type="aten_sub_tensor", delegated=0, non_delegated=1
),
"getitem": DelegationBreakdown(
op_type="getitem", delegated=0, non_delegated=2
),
}
self.assertEqual(
delegation_info.delegation_by_operator, expected_delegation_by_op_dict
)

self.assertIn(
"Total delegated subgraphs",
delegation_info.get_summary(),
)

df = delegation_info.get_operator_delegation_dataframe()
expected_df = pd.DataFrame(
{
"op_type": [
"aten_add_tensor",
"aten_mm_default",
"aten_sub_tensor",
"getitem",
"Total",
],
"occurrences_in_delegated_graphs": [2, 2, 0, 0, 4],
"occurrences_in_non_delegated_graphs": [0, 0, 1, 2, 3],
}
)
assert_frame_equal(expected_df, df)
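
As a cross-check, the asserted counts follow directly from the model and partitioner: AddMulPartitionerDemo lowers the two mm and two add nodes (4 delegated nodes split across 2 delegated subgraphs, since the sub node breaks the partition), leaving the sub node and the two getitem nodes (3 non-delegated nodes) in the toplevel graph. get_summary() on this delegation_info therefore prints:

Total delegated subgraphs: 2
Number of delegated nodes: 4
Number of non-delegated nodes: 3
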
166 changes: 166 additions & 0 deletions exir/backend/utils.py
@@ -6,10 +6,13 @@

import logging
import operator
import re
from collections import defaultdict
from dataclasses import asdict, dataclass
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union

import pandas as pd
import torch
from executorch.exir.backend.backend_details import ExportedProgram
from executorch.exir.backend.canonical_partitioners.duplicate_constant_node_pass import (
@@ -27,6 +30,12 @@
T_QuantPerTensor = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
T_DQuantPerTensor = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default

# Column names of the DataFrame returned by DelegationInfo.get_operator_delegation_dataframe(),
# which summarizes the delegation information grouped by operator type
_OCCURRENCES_IN_DELEGATED_GRAPHS = "occurrences_in_delegated_graphs"
_OCCURRENCES_IN_NON_DELEGATED_GRAPHS = "occurrences_in_non_delegated_graphs"


log: logging.Logger = logging.getLogger(__name__)


@@ -280,6 +289,163 @@ def get_delegates(graph: torch.fx.Graph) -> List[torch.fx.Node]:
]


@dataclass
class DelegationBreakdown:
"""
DelegationBreakdown contains the number of delegated and non-delegated nodes
of the operator type op_type.

Args:
op_type: The operator type, e.g. "aten_add_tensor".
delegated: The number of delegated nodes of this operator type.
non_delegated: The number of non-delegated nodes of this operator type.
"""

op_type: str = ""
delegated: int = 0
non_delegated: int = 0


@dataclass
class DelegationInfo:
"""
DelegationInfo contains information about a delegated graph module.

Args:
num_delegated_subgraphs: The number of delegated subgraphs.
num_delegated_nodes: The number of delegated nodes.
num_non_delegated_nodes: The number of non-delegated nodes.
delegation_by_operator: A dictionary of operator type to DelegationBreakdown.
"""

num_delegated_subgraphs: int
num_delegated_nodes: int
num_non_delegated_nodes: int
delegation_by_operator: Dict[str, DelegationBreakdown]

def get_summary(self) -> str:
"""
Get a summary of the delegation information in string format.

Args:
None

Returns:
A multi-line string reporting the number of delegated subgraphs and the
delegated and non-delegated node counts, for easy print-out.
"""

# Assemble and return the summary string
summary_str = f"Total delegated subgraphs: {self.num_delegated_subgraphs}\n"
summary_str += f"Number of delegated nodes: {self.num_delegated_nodes}\n"
summary_str += (
f"Number of non-delegated nodes: {self.num_non_delegated_nodes}\n"
)
return summary_str

def get_operator_delegation_dataframe(self) -> pd.DataFrame:
"""
Get the delegation information grouped by operator type in a pandas DataFrame.

Args:
None

Returns:
A pandas DataFrame with the following columns:
- op_type: The operator type, with the last row being "Total".
- occurrences_in_delegated_graphs: The number of occurrences of the op_type inside delegated subgraphs.
- occurrences_in_non_delegated_graphs: The number of occurrences of the op_type outside delegated subgraphs.
The last row holds the totals of delegated and non-delegated occurrences across all op_types.
"""

# Convert the dict to a dataframe
list_of_dicts = [
asdict(breakdown) for breakdown in self.delegation_by_operator.values()
]
df = pd.DataFrame(list_of_dicts)
# Rename columns for readability
df = df.rename(
columns={
"delegated": _OCCURRENCES_IN_DELEGATED_GRAPHS,
"non_delegated": _OCCURRENCES_IN_NON_DELEGATED_GRAPHS,
}
)
df = df.sort_values(by="op_type", ignore_index=True)

# Add a Total row at the bottom
total_delegated_nodes = df[_OCCURRENCES_IN_DELEGATED_GRAPHS].sum()
total_non_delegated_nodes = df[_OCCURRENCES_IN_NON_DELEGATED_GRAPHS].sum()
df.loc[len(df)] = ["Total", total_delegated_nodes, total_non_delegated_nodes]

return df


def get_delegation_info(
graph_module: torch.fx.GraphModule,
) -> DelegationInfo:
"""
Util function to get the delegation information of the given graph module.

Args:
graph_module: The lowered graph module to get the delegation information from.

Returns:
A DelegationInfo object containing the delegation information.
"""

def _get_op_type(node_name: str) -> str:
# node_name is in the format <op_type> or <op_type>_x, where x is an integer suffix.
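# e.g. "aten_mm_default_1" -> "aten_mm_default", while "getitem" is returned unchanged.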
return re.sub(r"_[\d]+$", "", node_name)

op_occurrences_dict = defaultdict(lambda: DelegationBreakdown())

def _insert_op_occurrences_dict(node_name: str, delegated: bool) -> None:
op_type = _get_op_type(node_name)
op_occurrences_dict[op_type].op_type = op_type
if delegated:
op_occurrences_dict[op_type].delegated += 1
else:
op_occurrences_dict[op_type].non_delegated += 1

delegated_subgraph_counter = 0

lowered_module_dict = {
node.name: getattr(graph_module, node.name)
for node in graph_module.graph.nodes
if node.op == "get_attr" and node.name.startswith("lowered_module_")
}

for node in graph_module.graph.nodes:
if (
node.op == "call_function"
and _get_op_type(node.name) != "executorch_call_delegate"
):
# Non-delegated node
_insert_op_occurrences_dict(node_name=node.name, delegated=False)
# Check if the node is a lowered module
if node.op == "get_attr" and node.name.startswith("lowered_module_"):
lowered_module = lowered_module_dict[node.name]
delegated_subgraph_counter += 1
for node_in_lowered_module in lowered_module.original_module.graph.nodes:
if node_in_lowered_module.op == "call_function":
# Delegated node
_insert_op_occurrences_dict(
node_name=node_in_lowered_module.name, delegated=True
)

# Calculate the total number of delegated and non-delegated nodes
num_delegated_nodes = 0
num_non_delegated_nodes = 0
for value in op_occurrences_dict.values():
num_delegated_nodes += value.delegated
num_non_delegated_nodes += value.non_delegated

return DelegationInfo(
num_delegated_nodes=num_delegated_nodes,
num_non_delegated_nodes=num_non_delegated_nodes,
num_delegated_subgraphs=delegated_subgraph_counter,
delegation_by_operator=op_occurrences_dict,
)


def print_delegated_graph(graph_module: torch.fx.GraphModule) -> str:
"""
Print the graph module together with the contents of each lowered_module in it (both the backend id and the original graph). Example output: