|
6 | 6 |
|
7 | 7 | import logging
|
8 | 8 | import operator
|
| 9 | +import re |
9 | 10 | from collections import defaultdict
|
| 11 | +from dataclasses import dataclass |
10 | 12 | from functools import lru_cache
|
11 | 13 | from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
|
12 | 14 |
|
| 15 | +import pandas as pd |
13 | 16 | import torch
|
14 | 17 | from executorch.exir.backend.backend_details import ExportedProgram
|
15 | 18 | from executorch.exir.common import setting_python_recursive_limit
|
@@ -211,6 +214,130 @@ def get_delegates(graph: torch.fx.Graph) -> List[torch.fx.Node]:
|
211 | 214 | ]
|
212 | 215 |
|
213 | 216 |
|
| 217 | +@dataclass |
| 218 | +class DelegationBreakdown: |
| 219 | + delegated: int |
| 220 | + non_delegated: int |
| 221 | + |
| 222 | + |
| 223 | +@dataclass |
| 224 | +class DelegationInfo: |
| 225 | + num_delegated_subgraphs: int |
| 226 | + num_delegated_nodes: int |
| 227 | + num_non_delegated_nodes: int |
| 228 | + delegation_by_operator: Dict[str, DelegationBreakdown] |
| 229 | + |
| 230 | + def get_summary(self) -> str: |
| 231 | + # Assemble and return the summary string |
| 232 | + summary_str = f"Total delegated subgraphs: {self.num_delegated_subgraphs}\n" |
| 233 | + summary_str += f"Number of delegated nodes: {self.num_delegated_nodes}\n" |
| 234 | + summary_str += ( |
| 235 | + f"Number of non-delegated nodes: {self.num_non_delegated_nodes}\n" |
| 236 | + ) |
| 237 | + return summary_str |
| 238 | + |
| 239 | + def get_operator_delegation_dataframe(self) -> pd.DataFrame: |
| 240 | + # Convert the dict to a dataframe |
| 241 | + df = pd.DataFrame( |
| 242 | + list(self.delegation_by_operator.items()), |
| 243 | + columns=["op type", "occurrences"], |
| 244 | + ) |
| 245 | + |
| 246 | + # Function to extract delegated and non-delegated fields |
| 247 | + def _extract_breakdown(row): |
| 248 | + return pd.Series( |
| 249 | + [row["occurrences"].delegated, row["occurrences"].non_delegated] |
| 250 | + ) |
| 251 | + |
| 252 | + df[ |
| 253 | + ["occurrences in delegated graphs", "occurrences in non-delegated graphs"] |
| 254 | + ] = df.apply(lambda row: _extract_breakdown(row), axis=1) |
| 255 | + df.drop(columns=["occurrences"], inplace=True) |
| 256 | + |
| 257 | + # Add a Total row at the bottom |
| 258 | + total_delegated_nodes = df["occurrences in delegated graphs"].sum() |
| 259 | + total_non_delegated_nodes = df["occurrences in non-delegated graphs"].sum() |
| 260 | + df = df.sort_values(by="op type", ignore_index=True) |
| 261 | + df.loc["Total"] = ["Total", total_delegated_nodes, total_non_delegated_nodes] |
| 262 | + |
| 263 | + return df |
| 264 | + |
| 265 | + |
| 266 | +def _get_op_type(node_name: str) -> str: |
| 267 | + # node_name is in format <op_type> or <op_type>_x in which x is an integer suffix. |
| 268 | + pattern = r"_(\d+)$" |
| 269 | + match = re.search(pattern, node_name) |
| 270 | + if match: |
| 271 | + return node_name[: match.start()] |
| 272 | + else: |
| 273 | + return node_name |
| 274 | + |
| 275 | + |
| 276 | +def get_delegation_info( |
| 277 | + graph_module: torch.fx.GraphModule, |
| 278 | +) -> DelegationInfo: |
| 279 | + """ |
| 280 | + Returns a string (for high level summary) and a DataFrame (for per op type summary) of the |
| 281 | + delegation info on the given graph module. |
| 282 | + """ |
| 283 | + |
| 284 | + op_occurrences_dict = {} |
| 285 | + |
| 286 | + def _insert_op_occurrences_dict(node_name: str, delegated: bool) -> None: |
| 287 | + op_type = _get_op_type(node_name) |
| 288 | + if op_type not in op_occurrences_dict: |
| 289 | + if delegated: |
| 290 | + op_occurrences_dict[op_type] = DelegationBreakdown( |
| 291 | + delegated=1, non_delegated=0 |
| 292 | + ) |
| 293 | + else: |
| 294 | + op_occurrences_dict[op_type] = DelegationBreakdown( |
| 295 | + delegated=0, non_delegated=1 |
| 296 | + ) |
| 297 | + else: |
| 298 | + if delegated: |
| 299 | + op_occurrences_dict[op_type].delegated += 1 |
| 300 | + else: |
| 301 | + op_occurrences_dict[op_type].non_delegated += 1 |
| 302 | + |
| 303 | + delegated_subgraph_counter = 0 |
| 304 | + |
| 305 | + lowered_module_dict = { |
| 306 | + node.name: getattr(graph_module, node.name) |
| 307 | + for node in graph_module.graph.nodes |
| 308 | + if node.op == "get_attr" and node.name.startswith("lowered_module_") |
| 309 | + } |
| 310 | + |
| 311 | + for node in graph_module.graph.nodes: |
| 312 | + if ( |
| 313 | + node.op == "call_function" |
| 314 | + and _get_op_type(node.name) != "executorch_call_delegate" |
| 315 | + ): |
| 316 | + _insert_op_occurrences_dict(node_name=node.name, delegated=False) |
| 317 | + if node.op == "get_attr" and node.name.startswith("lowered_module_"): |
| 318 | + lowered_module = lowered_module_dict[node.name] |
| 319 | + delegated_subgraph_counter += 1 |
| 320 | + for node_in_lowered_module in lowered_module.original_module.graph.nodes: |
| 321 | + if node_in_lowered_module.op == "call_function": |
| 322 | + _insert_op_occurrences_dict( |
| 323 | + node_name=node_in_lowered_module.name, delegated=True |
| 324 | + ) |
| 325 | + |
| 326 | + # Calculate the total number of delegated and non-delegated nodes |
| 327 | + num_delegated_nodes = 0 |
| 328 | + num_non_delegated_nodes = 0 |
| 329 | + for _, value in op_occurrences_dict.items(): |
| 330 | + num_delegated_nodes += value.delegated |
| 331 | + num_non_delegated_nodes += value.non_delegated |
| 332 | + |
| 333 | + return DelegationInfo( |
| 334 | + num_delegated_nodes=num_delegated_nodes, |
| 335 | + num_non_delegated_nodes=num_non_delegated_nodes, |
| 336 | + num_delegated_subgraphs=delegated_subgraph_counter, |
| 337 | + delegation_by_operator=op_occurrences_dict, |
| 338 | + ) |
| 339 | + |
| 340 | + |
214 | 341 | def print_delegated_graph(graph_module: torch.fx.GraphModule) -> str:
|
215 | 342 | """
|
216 | 343 | Print the graph of including lowered_module (both backend id and original graph) together with the graph module. Example output:
|
|
0 commit comments