|
6 | 6 |
|
7 | 7 | import logging
|
8 | 8 | import operator
|
| 9 | +import re |
9 | 10 | from collections import defaultdict
|
| 11 | +from dataclasses import asdict, dataclass |
10 | 12 | from functools import lru_cache
|
11 | 13 | from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
|
12 | 14 |
|
| 15 | +import pandas as pd |
13 | 16 | import torch
|
14 | 17 | from executorch.exir.backend.backend_details import ExportedProgram
|
15 | 18 | from executorch.exir.backend.canonical_partitioners.duplicate_constant_node_pass import (
|
|
27 | 30 | T_QuantPerTensor = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
|
28 | 31 | T_DQuantPerTensor = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
|
29 | 32 |
|
| 33 | +# Column names of the DataFrame returned by DelegationInfo.get_operator_delegation_dataframe() |
| 34 | +# which describes the summarized delegation information grouped by each operator type |
| 35 | +_OCCURRENCES_IN_DELEGATED_GRAPHS = "occurrences_in_delegated_graphs" |
| 36 | +_OCCURRENCES_IN_NON_DELEGATED_GRAPHS = "occurrences_in_non_delegated_graphs" |
| 37 | + |
| 38 | + |
30 | 39 | log: logging.Logger = logging.getLogger(__name__)
|
31 | 40 |
|
32 | 41 |
|
@@ -280,6 +289,163 @@ def get_delegates(graph: torch.fx.Graph) -> List[torch.fx.Node]:
|
280 | 289 | ]
|
281 | 290 |
|
282 | 291 |
|
| 292 | +@dataclass |
| 293 | +class DelegationBreakdown: |
| 294 | + """ |
| 295 | + DelegationBreakdown contains the number of delegated and non-delegated nodes |
| 296 | + of the operator type op_type. |
| 297 | +
|
| 298 | + Args: |
| 299 | + delegated: The number of delegated nodes. |
| 300 | + non_delegated: The number of non-delegated nodes. |
| 301 | + """ |
| 302 | + |
| 303 | + op_type: str = "" |
| 304 | + delegated: int = 0 |
| 305 | + non_delegated: int = 0 |
| 306 | + |
| 307 | + |
| 308 | +@dataclass |
| 309 | +class DelegationInfo: |
| 310 | + """ |
| 311 | + DelegationInfo contains information of a delegated graph module. |
| 312 | +
|
| 313 | + Args: |
| 314 | + num_delegated_subgraphs: The number of delegated subgraphs. |
| 315 | + num_delegated_nodes: The number of delegated nodes. |
| 316 | + num_non_delegated_nodes: The number of non-delegated nodes. |
| 317 | + delegation_by_operator: A dictionary of operator type to DelegationBreakdown. |
| 318 | + """ |
| 319 | + |
| 320 | + num_delegated_subgraphs: int |
| 321 | + num_delegated_nodes: int |
| 322 | + num_non_delegated_nodes: int |
| 323 | + delegation_by_operator: Dict[str, DelegationBreakdown] |
| 324 | + |
| 325 | + def get_summary(self) -> str: |
| 326 | + """ |
| 327 | + Get a summary of the delegation information in string format. |
| 328 | +
|
| 329 | + Args: |
| 330 | + None |
| 331 | +
|
| 332 | + Returns: |
| 333 | + A string containing information of some class attributes for easy print-out. |
| 334 | + """ |
| 335 | + |
| 336 | + # Assemble and return the summary string |
| 337 | + summary_str = f"Total delegated subgraphs: {self.num_delegated_subgraphs}\n" |
| 338 | + summary_str += f"Number of delegated nodes: {self.num_delegated_nodes}\n" |
| 339 | + summary_str += ( |
| 340 | + f"Number of non-delegated nodes: {self.num_non_delegated_nodes}\n" |
| 341 | + ) |
| 342 | + return summary_str |
| 343 | + |
| 344 | + def get_operator_delegation_dataframe(self) -> pd.DataFrame: |
| 345 | + """ |
| 346 | + Get the delegation information grouped by operator type in a pandas DataFrame. |
| 347 | +
|
| 348 | + Args: |
| 349 | + None |
| 350 | +
|
| 351 | + Returns: |
| 352 | + Returns a pandas DataFrame containing the following columns: |
| 353 | + - op_type: The operator type, with the last row being "Total". |
| 354 | + - occurrences_in_delegated_graphs: The number of occurrences of the op_type in delegated subgraphs. |
| 355 | + - occurrences_in_non_delegated_graphs: The number of occurrences of the op_type not in delegated subgraphs. |
| 356 | + With the last row being the total number of delegated and non-delegated occurrences of each op_type. |
| 357 | + """ |
| 358 | + |
| 359 | + # Convert the dict to a dataframe |
| 360 | + list_of_dicts = [ |
| 361 | + asdict(breakdown) for breakdown in self.delegation_by_operator.values() |
| 362 | + ] |
| 363 | + df = pd.DataFrame(list_of_dicts) |
| 364 | + # Rename columns for better understandability |
| 365 | + df = df.rename( |
| 366 | + columns={ |
| 367 | + "delegated": _OCCURRENCES_IN_DELEGATED_GRAPHS, |
| 368 | + "non_delegated": _OCCURRENCES_IN_NON_DELEGATED_GRAPHS, |
| 369 | + } |
| 370 | + ) |
| 371 | + df = df.sort_values(by="op_type", ignore_index=True) |
| 372 | + |
| 373 | + # Add a Total row at the bottom |
| 374 | + total_delegated_nodes = df[_OCCURRENCES_IN_DELEGATED_GRAPHS].sum() |
| 375 | + total_non_delegated_nodes = df[_OCCURRENCES_IN_NON_DELEGATED_GRAPHS].sum() |
| 376 | + df.loc[len(df)] = ["Total", total_delegated_nodes, total_non_delegated_nodes] |
| 377 | + |
| 378 | + return df |
| 379 | + |
| 380 | + |
| 381 | +def get_delegation_info( |
| 382 | + graph_module: torch.fx.GraphModule, |
| 383 | +) -> DelegationInfo: |
| 384 | + """ |
| 385 | + Util function to get the delegation information of the given graph module. |
| 386 | +
|
| 387 | + Args: |
| 388 | + graph_module: The lowered graph module to get the delegation information from. |
| 389 | +
|
| 390 | + Returns: |
| 391 | + Return a DelegationInfo object containing the delegation information. |
| 392 | + """ |
| 393 | + |
| 394 | + def _get_op_type(node_name: str) -> str: |
| 395 | + # node_name is in format <op_type> or <op_type>_x in which x is an integer suffix. |
| 396 | + return re.sub(r"_[\d]+$", "", node_name) |
| 397 | + |
| 398 | + op_occurrences_dict = defaultdict(lambda: DelegationBreakdown()) |
| 399 | + |
| 400 | + def _insert_op_occurrences_dict(node_name: str, delegated: bool) -> None: |
| 401 | + op_type = _get_op_type(node_name) |
| 402 | + op_occurrences_dict[op_type].op_type = op_type |
| 403 | + if delegated: |
| 404 | + op_occurrences_dict[op_type].delegated += 1 |
| 405 | + else: |
| 406 | + op_occurrences_dict[op_type].non_delegated += 1 |
| 407 | + |
| 408 | + delegated_subgraph_counter = 0 |
| 409 | + |
| 410 | + lowered_module_dict = { |
| 411 | + node.name: getattr(graph_module, node.name) |
| 412 | + for node in graph_module.graph.nodes |
| 413 | + if node.op == "get_attr" and node.name.startswith("lowered_module_") |
| 414 | + } |
| 415 | + |
| 416 | + for node in graph_module.graph.nodes: |
| 417 | + if ( |
| 418 | + node.op == "call_function" |
| 419 | + and _get_op_type(node.name) != "executorch_call_delegate" |
| 420 | + ): |
| 421 | + # Non-delegated node |
| 422 | + _insert_op_occurrences_dict(node_name=node.name, delegated=False) |
| 423 | + # Check if the node is a lowered module |
| 424 | + if node.op == "get_attr" and node.name.startswith("lowered_module_"): |
| 425 | + lowered_module = lowered_module_dict[node.name] |
| 426 | + delegated_subgraph_counter += 1 |
| 427 | + for node_in_lowered_module in lowered_module.original_module.graph.nodes: |
| 428 | + if node_in_lowered_module.op == "call_function": |
| 429 | + # Delegated node |
| 430 | + _insert_op_occurrences_dict( |
| 431 | + node_name=node_in_lowered_module.name, delegated=True |
| 432 | + ) |
| 433 | + |
| 434 | + # Calculate the total number of delegated and non-delegated nodes |
| 435 | + num_delegated_nodes = 0 |
| 436 | + num_non_delegated_nodes = 0 |
| 437 | + for value in op_occurrences_dict.values(): |
| 438 | + num_delegated_nodes += value.delegated |
| 439 | + num_non_delegated_nodes += value.non_delegated |
| 440 | + |
| 441 | + return DelegationInfo( |
| 442 | + num_delegated_nodes=num_delegated_nodes, |
| 443 | + num_non_delegated_nodes=num_non_delegated_nodes, |
| 444 | + num_delegated_subgraphs=delegated_subgraph_counter, |
| 445 | + delegation_by_operator=op_occurrences_dict, |
| 446 | + ) |
| 447 | + |
| 448 | + |
283 | 449 | def print_delegated_graph(graph_module: torch.fx.GraphModule) -> str:
|
284 | 450 | """
|
285 | 451 | Print the graph of including lowered_module (both backend id and original graph) together with the graph module. Example output:
|
|
0 commit comments