Skip to content

Commit 7115281

Browse files
committed
feat: Add optional filepath to save
- Add detailed layer information for excluded ops
1 parent e6302cf commit 7115281

File tree

5 files changed

+66
-9
lines changed

5 files changed

+66
-9
lines changed

py/torch_tensorrt/dynamo/_DryRunTracker.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
import logging
22
import math
3+
import operator
4+
import os
35
from dataclasses import dataclass, field
4-
from typing import Any, Dict, List
6+
from typing import Any, Dict, List, Union
57

8+
import torch
69
from torch_tensorrt.dynamo._settings import CompilationSettings
10+
from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterRegistry
11+
from torch_tensorrt.dynamo.conversion.converter_utils import get_node_name
712

813
logger = logging.getLogger(__name__)
914

@@ -44,6 +49,7 @@ class DryRunTracker:
4449
tensorrt_graph_count (int): Number of TensorRT engines to be generated
4550
compilation_settings (CompilationSettings): User Compilation Settings
4651
unsupported_ops (Dict[str, int]): Set of operators not supported in TRT
52+
to_run_in_torch (List[str]): List of nodes to run in Torch
4753
"""
4854

4955
total_ops_in_graph: int = 0
@@ -58,9 +64,12 @@ class DryRunTracker:
5864
default_factory=CompilationSettings
5965
)
6066
unsupported_ops: Dict[str, int] = field(default_factory=dict)
67+
to_run_in_torch: List[str] = field(default_factory=list)
6168

6269

63-
def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> None:
70+
def dryrun_stats_display(
71+
dryrun_tracker: DryRunTracker, dryrun_enabled: Union[bool, str]
72+
) -> None:
6473
"""Displays statistics about the dryrun either to debug logs or stdout"""
6574
formatted_stats = "\n"
6675

@@ -71,7 +80,19 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) ->
7180
f"of which {dryrun_tracker.supported_ops_in_graph} operators are supported, "
7281
f"{round(dryrun_tracker.supported_ops_in_graph*100/dryrun_tracker.total_ops_in_graph, 2)}% coverage\n\n"
7382
)
74-
formatted_stats += f"The following ops are currently unsupported and set to run in Torch: {dryrun_tracker.unsupported_ops}\n\n"
83+
if dryrun_tracker.unsupported_ops:
84+
parsed_ops = "\n".join(
85+
[f"{str(k)}: {str(v)}" for k, v in dryrun_tracker.unsupported_ops.items()]
86+
)
87+
formatted_stats += f"The following ops are currently unsupported or excluded from conversion, and are listed with their op-count in the graph:\n {parsed_ops}\n\n"
88+
89+
if dryrun_tracker.to_run_in_torch:
90+
formatted_nodes = "\n".join(dryrun_tracker.to_run_in_torch)
91+
formatted_stats += (
92+
f"The following nodes are currently set to run in Torch:\n{formatted_nodes}\n"
93+
"Note: Some of the above nodes may be supported, but were not included in a TRT graph by the partitioner\n\n"
94+
)
95+
7596
formatted_stats += f"Compiled with: {dryrun_tracker.compilation_settings}\n\n"
7697

7798
assert len(dryrun_tracker.per_subgraph_data) == dryrun_tracker.tensorrt_graph_count
@@ -184,8 +205,17 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) ->
184205
)
185206

186207
# If user specified "dryrun=True", print to stdout, else debug
208+
# If user specified a filepath, save the output to the path as well
187209
if dryrun_enabled:
188210
print(formatted_stats)
211+
if isinstance(dryrun_enabled, str):
212+
if os.path.exists(dryrun_enabled):
213+
logger.warning(
214+
f"File already exists at path {dryrun_enabled}, not saving dryrun output"
215+
)
216+
else:
217+
with open(dryrun_enabled, "w+") as f:
218+
f.write(formatted_stats)
189219
else:
190220
logger.debug(formatted_stats)
191221

@@ -225,3 +255,23 @@ def input_formatter_helper(shapes: Any, dtypes: Any) -> str:
225255
)
226256

227257
return input_formatter_helper(shapes, dtypes)[:-2]
258+
259+
260+
def parse_non_trt_nodes(graph_module: torch.fx.GraphModule) -> List[str]:
    """Collect descriptions of the computational nodes in a GraphModule.

    Walks every node of the graph and keeps only ``call_function`` and
    ``call_method`` nodes, skipping ``getitem`` since it is a
    Tensor-collection accessor rather than a true computational op.

    Returns a list of human-readable strings, one per kept node, giving the
    node's qualified target name and its layer location.
    """
    return [
        f"Node: {ConverterRegistry.qualified_name_or_str(node.target)}, "
        f"with layer location: {get_node_name(node)}"
        for node in graph_module.graph.nodes
        if node.op in ("call_function", "call_method")
        and node.target != operator.getitem
    ]

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
DryRunTracker,
3434
PerSubgraphData,
3535
dryrun_stats_display,
36+
parse_non_trt_nodes,
3637
)
3738
from torch_tensorrt.dynamo.conversion import (
3839
CompilationSettings,
@@ -296,6 +297,10 @@ def compile_module(
296297

297298
dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators
298299

300+
# The global partitioner leaves non-TRT nodes as-is
301+
if not settings.use_fast_partitioner:
302+
dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(partitioned_module))
303+
299304
# Store TRT replicas of Torch subgraphs
300305
trt_modules = {}
301306
# Iterate over all components that can be accelerated
@@ -304,6 +309,7 @@ def compile_module(
304309
submodule = getattr(partitioned_module, name)
305310
# Criteria for a module to be convertible to TRT
306311
if settings.use_fast_partitioner and "_run_on_acc" not in name:
312+
dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(submodule))
307313
continue
308314

309315
subgraph_data = PerSubgraphData()

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass, field
2-
from typing import Optional, Set
2+
from typing import Optional, Set, Union
33

44
import torch
55
from torch_tensorrt._Device import Device
@@ -47,8 +47,9 @@ class CompilationSettings:
4747
device (Device): GPU to compile the model on
4848
require_full_compilation (bool): Whether to require the graph is fully compiled in TensorRT.
4949
Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path
50-
dryrun (bool): Toggle "Dryrun" mode, which runs everything through partitioning, short of conversion to
51-
TRT Engines. Prints detailed logs of the graph structure and nature of partitioning
50+
dryrun (Union[bool, str]): Toggle "Dryrun" mode, which runs everything through partitioning, short of conversion to
51+
TRT Engines. Prints detailed logs of the graph structure and nature of partitioning. Optionally saves the
52+
output to a file if a string path is specified
5253
"""
5354

5455
precision: torch.dtype = PRECISION
@@ -66,4 +67,4 @@ class CompilationSettings:
6667
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
6768
device: Device = field(default_factory=default_device)
6869
require_full_compilation: bool = REQUIRE_FULL_COMPILATION
69-
dryrun: bool = DRYRUN
70+
dryrun: Union[bool, str] = DRYRUN

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def is_node_supported(
4242
node_name = ConverterRegistry.qualified_name_or_str(node.target)
4343

4444
if (
45-
node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name)
45+
node in CONVERTERS or (node.op == "get_attr")
4646
) and node_name not in self.torch_executed_ops:
4747
# If node is a proper, supported computational node, store the operator
4848
if not node.is_impure():

py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def is_node_supported(
150150
node_name = ConverterRegistry.qualified_name_or_str(node.target)
151151

152152
if (
153-
node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name)
153+
node in CONVERTERS or (node.op == "get_attr")
154154
) and node_name not in self.torch_executed_ops:
155155
# If node is a proper, supported computational node, store the operator
156156
if not node.is_impure():

0 commit comments

Comments
 (0)