|
6 | 6 |
|
7 | 7 | import logging
|
8 | 8 | import operator
|
| 9 | +import re |
9 | 10 | from collections import defaultdict
|
| 11 | +from dataclasses import dataclass |
10 | 12 | from functools import lru_cache
|
11 | 13 | from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
|
12 | 14 |
|
| 15 | +import pandas as pd |
13 | 16 | import torch
|
14 | 17 | from executorch.exir.backend.backend_details import ExportedProgram
|
15 | 18 | from executorch.exir.backend.canonical_partitioners.duplicate_constant_node_pass import (
|
@@ -280,6 +283,172 @@ def get_delegates(graph: torch.fx.Graph) -> List[torch.fx.Node]:
|
280 | 283 | ]
|
281 | 284 |
|
282 | 285 |
|
| 286 | +@dataclass |
| 287 | +class DelegationBreakdown: |
| 288 | + """ |
| 289 | + DelegationBreakdown contains for the number of delegated nodes and non-delegated nodes |
| 290 | + of an operator type. |
| 291 | +
|
| 292 | + Args: |
| 293 | + delegated: The number of delegated nodes. |
| 294 | + non_delegated: The number of non-delegated nodes. |
| 295 | + """ |
| 296 | + |
| 297 | + delegated: int |
| 298 | + non_delegated: int |
| 299 | + |
| 300 | + |
| 301 | +@dataclass |
| 302 | +class DelegationInfo: |
| 303 | + """ |
| 304 | + DelegationInfo contains information of a delegated graph module. |
| 305 | +
|
| 306 | + Args: |
| 307 | + num_delegated_subgraphs: The number of delegated subgraphs. |
| 308 | + num_delegated_nodes: The number of delegated nodes. |
| 309 | + num_non_delegated_nodes: The number of non-delegated nodes. |
| 310 | + delegation_by_operator: A dictionary of operator type to DelegationBreakdown. |
| 311 | + """ |
| 312 | + |
| 313 | + num_delegated_subgraphs: int |
| 314 | + num_delegated_nodes: int |
| 315 | + num_non_delegated_nodes: int |
| 316 | + delegation_by_operator: Dict[str, DelegationBreakdown] |
| 317 | + |
| 318 | + def get_summary(self) -> str: |
| 319 | + """ |
| 320 | + Get a summary of the delegation information in string format. |
| 321 | +
|
| 322 | + Args: |
| 323 | + None |
| 324 | +
|
| 325 | + Returns: |
| 326 | + A string containing information of some class attributes for easy print-out. |
| 327 | + """ |
| 328 | + |
| 329 | + # Assemble and return the summary string |
| 330 | + summary_str = f"Total delegated subgraphs: {self.num_delegated_subgraphs}\n" |
| 331 | + summary_str += f"Number of delegated nodes: {self.num_delegated_nodes}\n" |
| 332 | + summary_str += ( |
| 333 | + f"Number of non-delegated nodes: {self.num_non_delegated_nodes}\n" |
| 334 | + ) |
| 335 | + return summary_str |
| 336 | + |
| 337 | + def get_operator_delegation_dataframe(self) -> pd.DataFrame: |
| 338 | + """ |
| 339 | + Get the delegation information grouped by operator type in a pandas DataFrame. |
| 340 | +
|
| 341 | + Args: |
| 342 | + None |
| 343 | +
|
| 344 | + Returns: |
| 345 | + Returns a pandas DataFrame containing the following columns: |
| 346 | + - op_type: The operator type, with the last row being "Total". |
| 347 | + - occurrences_in_delegated_graphs: The number of occurrences of the op_type in delegated subgraphs. |
| 348 | + - occurrences_in_non_delegated_graphs: The number of occurrences of the op_type not in delegated subgraphs. |
| 349 | + With the last row being the total number of delegated and non-delegated occurrences of each op_type. |
| 350 | + """ |
| 351 | + |
| 352 | + # Convert the dict to a dataframe |
| 353 | + df = pd.DataFrame( |
| 354 | + list(self.delegation_by_operator.items()), |
| 355 | + columns=["op_type", "occurrences"], |
| 356 | + ) |
| 357 | + |
| 358 | + # Function to extract delegated and non-delegated fields |
| 359 | + def _extract_breakdown(row): |
| 360 | + return pd.Series( |
| 361 | + [row["occurrences"].delegated, row["occurrences"].non_delegated] |
| 362 | + ) |
| 363 | + |
| 364 | + df[ |
| 365 | + ["occurrences_in_delegated_graphs", "occurrences_in_non_delegated_graphs"] |
| 366 | + ] = df.apply(lambda row: _extract_breakdown(row), axis=1) |
| 367 | + df.drop(columns=["occurrences"], inplace=True) |
| 368 | + df = df.sort_values(by="op_type", ignore_index=True) |
| 369 | + |
| 370 | + # Add a Total row at the bottom |
| 371 | + total_delegated_nodes = df["occurrences_in_delegated_graphs"].sum() |
| 372 | + total_non_delegated_nodes = df["occurrences_in_non_delegated_graphs"].sum() |
| 373 | + df.loc[len(df)] = ["Total", total_delegated_nodes, total_non_delegated_nodes] |
| 374 | + |
| 375 | + return df |
| 376 | + |
| 377 | + |
| 378 | +def get_delegation_info( |
| 379 | + graph_module: torch.fx.GraphModule, |
| 380 | +) -> DelegationInfo: |
| 381 | + """ |
| 382 | + Util function to get the delegation information of the given graph module. |
| 383 | +
|
| 384 | + Args: |
| 385 | + graph_module: The lowered graph module to get the delegation information from. |
| 386 | +
|
| 387 | + Returns: |
| 388 | + Return a DelegationInfo object containing the delegation information. |
| 389 | + """ |
| 390 | + |
| 391 | + def _get_op_type(node_name: str) -> str: |
| 392 | + # node_name is in format <op_type> or <op_type>_x in which x is an integer suffix. |
| 393 | + return re.sub(r"_[\d]+$", "", node_name) |
| 394 | + |
| 395 | + op_occurrences_dict = {} |
| 396 | + |
| 397 | + def _insert_op_occurrences_dict(node_name: str, delegated: bool) -> None: |
| 398 | + op_type = _get_op_type(node_name) |
| 399 | + if op_type not in op_occurrences_dict: |
| 400 | + if delegated: |
| 401 | + op_occurrences_dict[op_type] = DelegationBreakdown( |
| 402 | + delegated=1, non_delegated=0 |
| 403 | + ) |
| 404 | + else: |
| 405 | + op_occurrences_dict[op_type] = DelegationBreakdown( |
| 406 | + delegated=0, non_delegated=1 |
| 407 | + ) |
| 408 | + else: |
| 409 | + if delegated: |
| 410 | + op_occurrences_dict[op_type].delegated += 1 |
| 411 | + else: |
| 412 | + op_occurrences_dict[op_type].non_delegated += 1 |
| 413 | + |
| 414 | + delegated_subgraph_counter = 0 |
| 415 | + |
| 416 | + lowered_module_dict = { |
| 417 | + node.name: getattr(graph_module, node.name) |
| 418 | + for node in graph_module.graph.nodes |
| 419 | + if node.op == "get_attr" and node.name.startswith("lowered_module_") |
| 420 | + } |
| 421 | + |
| 422 | + for node in graph_module.graph.nodes: |
| 423 | + if ( |
| 424 | + node.op == "call_function" |
| 425 | + and _get_op_type(node.name) != "executorch_call_delegate" |
| 426 | + ): |
| 427 | + _insert_op_occurrences_dict(node_name=node.name, delegated=False) |
| 428 | + if node.op == "get_attr" and node.name.startswith("lowered_module_"): |
| 429 | + lowered_module = lowered_module_dict[node.name] |
| 430 | + delegated_subgraph_counter += 1 |
| 431 | + for node_in_lowered_module in lowered_module.original_module.graph.nodes: |
| 432 | + if node_in_lowered_module.op == "call_function": |
| 433 | + _insert_op_occurrences_dict( |
| 434 | + node_name=node_in_lowered_module.name, delegated=True |
| 435 | + ) |
| 436 | + |
| 437 | + # Calculate the total number of delegated and non-delegated nodes |
| 438 | + num_delegated_nodes = 0 |
| 439 | + num_non_delegated_nodes = 0 |
| 440 | + for _, value in op_occurrences_dict.items(): |
| 441 | + num_delegated_nodes += value.delegated |
| 442 | + num_non_delegated_nodes += value.non_delegated |
| 443 | + |
| 444 | + return DelegationInfo( |
| 445 | + num_delegated_nodes=num_delegated_nodes, |
| 446 | + num_non_delegated_nodes=num_non_delegated_nodes, |
| 447 | + num_delegated_subgraphs=delegated_subgraph_counter, |
| 448 | + delegation_by_operator=op_occurrences_dict, |
| 449 | + ) |
| 450 | + |
| 451 | + |
283 | 452 | def print_delegated_graph(graph_module: torch.fx.GraphModule) -> str:
|
284 | 453 | """
|
285 | 454 | Print the graph of including lowered_module (both backend id and original graph) together with the graph module. Example output:
|
|
0 commit comments