|
6 | 6 |
|
7 | 7 | import logging
|
8 | 8 | import operator
|
9 |
| -import re |
10 | 9 | from collections import defaultdict
|
11 |
| -from dataclasses import asdict, dataclass |
12 | 10 | from functools import lru_cache
|
13 | 11 | from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
|
14 | 12 |
|
15 |
| -import pandas as pd |
16 | 13 | import torch
|
17 | 14 | from executorch.exir.backend.backend_details import ExportedProgram
|
18 | 15 | from executorch.exir.backend.canonical_partitioners.duplicate_constant_node_pass import (
|
|
30 | 27 | T_QuantPerTensor = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
|
31 | 28 | T_DQuantPerTensor = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
|
32 | 29 |
|
33 |
| -# Column names of the DataFrame returned by DelegationInfo.get_operator_delegation_dataframe() |
34 |
| -# which describes the summarized delegation information grouped by each operator type |
35 |
| -_OCCURRENCES_IN_DELEGATED_GRAPHS = "occurrences_in_delegated_graphs" |
36 |
| -_OCCURRENCES_IN_NON_DELEGATED_GRAPHS = "occurrences_in_non_delegated_graphs" |
37 |
| - |
38 | 30 |
|
39 | 31 | log: logging.Logger = logging.getLogger(__name__)
|
40 | 32 |
|
@@ -291,163 +283,6 @@ def get_delegates(graph: torch.fx.Graph) -> List[torch.fx.Node]:
|
291 | 283 | ]
|
292 | 284 |
|
293 | 285 |
|
294 |
| -@dataclass |
295 |
| -class DelegationBreakdown: |
296 |
| - """ |
297 |
| - DelegationBreakdown contains the number of delegated and non-delegated nodes |
298 |
| - of the operator type op_type. |
299 |
| -
|
300 |
| - Args: |
301 |
| - delegated: The number of delegated nodes. |
302 |
| - non_delegated: The number of non-delegated nodes. |
303 |
| - """ |
304 |
| - |
305 |
| - op_type: str = "" |
306 |
| - delegated: int = 0 |
307 |
| - non_delegated: int = 0 |
308 |
| - |
309 |
| - |
310 |
| -@dataclass |
311 |
| -class DelegationInfo: |
312 |
| - """ |
313 |
| - DelegationInfo contains information of a delegated graph module. |
314 |
| -
|
315 |
| - Args: |
316 |
| - num_delegated_subgraphs: The number of delegated subgraphs. |
317 |
| - num_delegated_nodes: The number of delegated nodes. |
318 |
| - num_non_delegated_nodes: The number of non-delegated nodes. |
319 |
| - delegation_by_operator: A dictionary of operator type to DelegationBreakdown. |
320 |
| - """ |
321 |
| - |
322 |
| - num_delegated_subgraphs: int |
323 |
| - num_delegated_nodes: int |
324 |
| - num_non_delegated_nodes: int |
325 |
| - delegation_by_operator: Dict[str, DelegationBreakdown] |
326 |
| - |
327 |
| - def get_summary(self) -> str: |
328 |
| - """ |
329 |
| - Get a summary of the delegation information in string format. |
330 |
| -
|
331 |
| - Args: |
332 |
| - None |
333 |
| -
|
334 |
| - Returns: |
335 |
| - A string containing information of some class attributes for easy print-out. |
336 |
| - """ |
337 |
| - |
338 |
| - # Assemble and return the summary string |
339 |
| - summary_str = f"Total delegated subgraphs: {self.num_delegated_subgraphs}\n" |
340 |
| - summary_str += f"Number of delegated nodes: {self.num_delegated_nodes}\n" |
341 |
| - summary_str += ( |
342 |
| - f"Number of non-delegated nodes: {self.num_non_delegated_nodes}\n" |
343 |
| - ) |
344 |
| - return summary_str |
345 |
| - |
346 |
| - def get_operator_delegation_dataframe(self) -> pd.DataFrame: |
347 |
| - """ |
348 |
| - Get the delegation information grouped by operator type in a pandas DataFrame. |
349 |
| -
|
350 |
| - Args: |
351 |
| - None |
352 |
| -
|
353 |
| - Returns: |
354 |
| - Returns a pandas DataFrame containing the following columns: |
355 |
| - - op_type: The operator type, with the last row being "Total". |
356 |
| - - occurrences_in_delegated_graphs: The number of occurrences of the op_type in delegated subgraphs. |
357 |
| - - occurrences_in_non_delegated_graphs: The number of occurrences of the op_type not in delegated subgraphs. |
358 |
| - With the last row being the total number of delegated and non-delegated occurrences of each op_type. |
359 |
| - """ |
360 |
| - |
361 |
| - # Convert the dict to a dataframe |
362 |
| - list_of_dicts = [ |
363 |
| - asdict(breakdown) for breakdown in self.delegation_by_operator.values() |
364 |
| - ] |
365 |
| - df = pd.DataFrame(list_of_dicts) |
366 |
| - # Rename columns for better understandability |
367 |
| - df = df.rename( |
368 |
| - columns={ |
369 |
| - "delegated": _OCCURRENCES_IN_DELEGATED_GRAPHS, |
370 |
| - "non_delegated": _OCCURRENCES_IN_NON_DELEGATED_GRAPHS, |
371 |
| - } |
372 |
| - ) |
373 |
| - df = df.sort_values(by="op_type", ignore_index=True) |
374 |
| - |
375 |
| - # Add a Total row at the bottom |
376 |
| - total_delegated_nodes = df[_OCCURRENCES_IN_DELEGATED_GRAPHS].sum() |
377 |
| - total_non_delegated_nodes = df[_OCCURRENCES_IN_NON_DELEGATED_GRAPHS].sum() |
378 |
| - df.loc[len(df)] = ["Total", total_delegated_nodes, total_non_delegated_nodes] |
379 |
| - |
380 |
| - return df |
381 |
| - |
382 |
| - |
383 |
| -def get_delegation_info( |
384 |
| - graph_module: torch.fx.GraphModule, |
385 |
| -) -> DelegationInfo: |
386 |
| - """ |
387 |
| - Util function to get the delegation information of the given graph module. |
388 |
| -
|
389 |
| - Args: |
390 |
| - graph_module: The lowered graph module to get the delegation information from. |
391 |
| -
|
392 |
| - Returns: |
393 |
| - Return a DelegationInfo object containing the delegation information. |
394 |
| - """ |
395 |
| - |
396 |
| - def _get_op_type(node_name: str) -> str: |
397 |
| - # node_name is in format <op_type> or <op_type>_x in which x is an integer suffix. |
398 |
| - return re.sub(r"_[\d]+$", "", node_name) |
399 |
| - |
400 |
| - op_occurrences_dict = defaultdict(lambda: DelegationBreakdown()) |
401 |
| - |
402 |
| - def _insert_op_occurrences_dict(node_name: str, delegated: bool) -> None: |
403 |
| - op_type = _get_op_type(node_name) |
404 |
| - op_occurrences_dict[op_type].op_type = op_type |
405 |
| - if delegated: |
406 |
| - op_occurrences_dict[op_type].delegated += 1 |
407 |
| - else: |
408 |
| - op_occurrences_dict[op_type].non_delegated += 1 |
409 |
| - |
410 |
| - delegated_subgraph_counter = 0 |
411 |
| - |
412 |
| - lowered_module_dict = { |
413 |
| - node.name: getattr(graph_module, node.name) |
414 |
| - for node in graph_module.graph.nodes |
415 |
| - if node.op == "get_attr" and node.name.startswith("lowered_module_") |
416 |
| - } |
417 |
| - |
418 |
| - for node in graph_module.graph.nodes: |
419 |
| - if ( |
420 |
| - node.op == "call_function" |
421 |
| - and _get_op_type(node.name) != "executorch_call_delegate" |
422 |
| - ): |
423 |
| - # Non-delegated node |
424 |
| - _insert_op_occurrences_dict(node_name=node.name, delegated=False) |
425 |
| - # Check if the node is a lowered module |
426 |
| - if node.op == "get_attr" and node.name.startswith("lowered_module_"): |
427 |
| - lowered_module = lowered_module_dict[node.name] |
428 |
| - delegated_subgraph_counter += 1 |
429 |
| - for node_in_lowered_module in lowered_module.original_module.graph.nodes: |
430 |
| - if node_in_lowered_module.op == "call_function": |
431 |
| - # Delegated node |
432 |
| - _insert_op_occurrences_dict( |
433 |
| - node_name=node_in_lowered_module.name, delegated=True |
434 |
| - ) |
435 |
| - |
436 |
| - # Calculate the total number of delegated and non-delegated nodes |
437 |
| - num_delegated_nodes = 0 |
438 |
| - num_non_delegated_nodes = 0 |
439 |
| - for value in op_occurrences_dict.values(): |
440 |
| - num_delegated_nodes += value.delegated |
441 |
| - num_non_delegated_nodes += value.non_delegated |
442 |
| - |
443 |
| - return DelegationInfo( |
444 |
| - num_delegated_nodes=num_delegated_nodes, |
445 |
| - num_non_delegated_nodes=num_non_delegated_nodes, |
446 |
| - num_delegated_subgraphs=delegated_subgraph_counter, |
447 |
| - delegation_by_operator=op_occurrences_dict, |
448 |
| - ) |
449 |
| - |
450 |
| - |
451 | 286 | def print_delegated_graph(graph_module: torch.fx.GraphModule) -> None:
|
452 | 287 | """
|
453 | 288 | Print the formatted graph string.
|
|
0 commit comments