Skip to content

Commit 212019b

Browse files
committed
feat: Add optional filepath to save
- Add detailed layer information for excluded ops
1 parent d121545 commit 212019b

File tree

3 files changed

+64
-7
lines changed

3 files changed

+64
-7
lines changed

py/torch_tensorrt/dynamo/_DryRunTracker.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
import logging
22
import math
3+
import operator
4+
import os
35
from dataclasses import dataclass, field
4-
from typing import Any, Dict, List
6+
from typing import Any, Dict, List, Union
57

8+
import torch
69
from torch_tensorrt.dynamo._settings import CompilationSettings
10+
from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterRegistry
11+
from torch_tensorrt.dynamo.conversion.converter_utils import get_node_name
712

813
logger = logging.getLogger(__name__)
914

@@ -44,6 +49,7 @@ class DryRunTracker:
4449
tensorrt_graph_count (int): Number of TensorRT engines to be generated
4550
compilation_settings (CompilationSettings): User Compilation Settings
4651
unsupported_ops (Dict[str, int]): Set of operators not supported in TRT
52+
to_run_in_torch (List[str]): Set of nodes to run in Torch
4753
"""
4854

4955
total_ops_in_graph: int = 0
@@ -58,9 +64,12 @@ class DryRunTracker:
5864
default_factory=CompilationSettings
5965
)
6066
unsupported_ops: Dict[str, int] = field(default_factory=dict)
67+
to_run_in_torch: List[str] = field(default_factory=list)
6168

6269

63-
def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> None:
70+
def dryrun_stats_display(
71+
dryrun_tracker: DryRunTracker, dryrun_enabled: Union[bool, str]
72+
) -> None:
6473
"""Displays statistics about the dryrun either to debug logs or stdout"""
6574
formatted_stats = "\n"
6675

@@ -71,7 +80,19 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) ->
7180
f"of which {dryrun_tracker.supported_ops_in_graph} operators are supported, "
7281
f"{round(dryrun_tracker.supported_ops_in_graph*100/dryrun_tracker.total_ops_in_graph, 2)}% coverage\n\n"
7382
)
74-
formatted_stats += f"The following ops are currently unsupported and set to run in Torch: {dryrun_tracker.unsupported_ops}\n\n"
83+
if dryrun_tracker.unsupported_ops:
84+
parsed_ops = "\n".join(
85+
[f"{str(k)}: {str(v)}" for k, v in dryrun_tracker.unsupported_ops.items()]
86+
)
87+
formatted_stats += f"The following ops are currently unsupported or excluded from conversion, and are listed with their op-count in the graph:\n {parsed_ops}\n\n"
88+
89+
if dryrun_tracker.to_run_in_torch:
90+
formatted_nodes = "\n".join(dryrun_tracker.to_run_in_torch)
91+
formatted_stats += (
92+
f"The following nodes are currently set to run in Torch:\n{formatted_nodes}\n"
93+
"Note: Some of the above nodes may be supported, but were not included in a TRT graph by the partitioner\n\n"
94+
)
95+
7596
formatted_stats += f"Compiled with: {dryrun_tracker.compilation_settings}\n\n"
7697

7798
assert len(dryrun_tracker.per_subgraph_data) == dryrun_tracker.tensorrt_graph_count
@@ -184,8 +205,17 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) ->
184205
)
185206

186207
# If user specified "dryrun=True", print to stdout, else debug
208+
# If user specified a filepath, save the output to the path as well
187209
if dryrun_enabled:
188210
print(formatted_stats)
211+
if isinstance(dryrun_enabled, str):
212+
if os.path.exists(dryrun_enabled):
213+
logger.warning(
214+
f"File already exists at path {dryrun_enabled}, not saving dryrun output"
215+
)
216+
else:
217+
with open(dryrun_enabled, "w+") as f:
218+
f.write(formatted_stats)
189219
else:
190220
logger.debug(formatted_stats)
191221

@@ -225,3 +255,23 @@ def input_formatter_helper(shapes: Any, dtypes: Any) -> str:
225255
)
226256

227257
return input_formatter_helper(shapes, dtypes)[:-2]
258+
259+
260+
def parse_non_trt_nodes(graph_module: torch.fx.GraphModule) -> List[str]:
261+
"""Parses call_function and call_method nodes from a GraphModule
262+
Excludes getitem nodes
263+
264+
Returns a string representation of the nodes
265+
"""
266+
to_run_in_torch = []
267+
for node in graph_module.graph.nodes:
268+
# getitem nodes are excluded since they are a Tensor-collection op
269+
if (
270+
node.op in ("call_function", "call_method")
271+
and node.target != operator.getitem
272+
):
273+
to_run_in_torch.append(
274+
f"Node: {ConverterRegistry.qualified_name_or_str(node.target)}, "
275+
f"with layer location: {get_node_name(node)}"
276+
)
277+
return to_run_in_torch

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
DryRunTracker,
4242
PerSubgraphData,
4343
dryrun_stats_display,
44+
parse_non_trt_nodes,
4445
)
4546
from torch_tensorrt.dynamo.conversion import (
4647
CompilationSettings,
@@ -319,6 +320,10 @@ def compile_module(
319320

320321
dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators
321322

323+
# The global partitioner leaves non-TRT nodes as-is
324+
if not settings.use_fast_partitioner:
325+
dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(partitioned_module))
326+
322327
# Store TRT replicas of Torch subgraphs
323328
trt_modules = {}
324329
# Iterate over all components that can be accelerated
@@ -327,6 +332,7 @@ def compile_module(
327332
submodule = getattr(partitioned_module, name)
328333
# Criteria for a module to be convertible to TRT
329334
if settings.use_fast_partitioner and "_run_on_acc" not in name:
335+
dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(submodule))
330336
continue
331337

332338
subgraph_data = PerSubgraphData()

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass, field
2-
from typing import Optional, Set
2+
from typing import Optional, Set, Union
33

44
import torch
55
from tensorrt import EngineCapability
@@ -64,8 +64,9 @@ class CompilationSettings:
6464
dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
6565
dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
6666
dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
67-
dryrun (bool): Toggle "Dryrun" mode, which runs everything through partitioning, short of conversion to
68-
TRT Engines. Prints detailed logs of the graph structure and nature of partitioning
67+
dryrun (Union[bool, str]): Toggle "Dryrun" mode, which runs everything through partitioning, short of conversion to
68+
TRT Engines. Prints detailed logs of the graph structure and nature of partitioning. Optionally saves the
69+
ouptut to a file if a string path is specified
6970
"""
7071

7172
precision: torch.dtype = PRECISION
@@ -91,4 +92,4 @@ class CompilationSettings:
9192
dla_sram_size: int = DLA_SRAM_SIZE
9293
dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE
9394
dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE
94-
dryrun: bool = DRYRUN
95+
dryrun: Union[bool, str] = DRYRUN

0 commit comments

Comments
 (0)