Commit 635e26a
feat: Add dryrun feature to Dynamo paths
- Enables building of TRT engines with "dryrun" capabilities, meaning all of the phases except conversion are run, and verbose logs of the graph structure and composition are printed for the user
- Improves general-purpose debug logging by printing dryrun stats to the debug logs regardless of option specification
- Provides an intuitive schematic of the graph engines, inputs, and code path through the course of the graph
1 parent 8554782 commit 635e26a
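
As a quick illustration (not part of the commit), a usage sketch of the new option against the Dynamo path. The model and input shapes here are hypothetical; `ir="dynamo"` selects the path this commit extends, and `dryrun` is the flag it adds:

```python
import torch
import torch_tensorrt

# Illustrative model; any torch.nn.Module handled by the Dynamo path works
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 64, kernel_size=7, stride=2),
    torch.nn.ReLU(),
).eval().cuda()

inputs = [torch.randn(1, 3, 224, 224).cuda()]

# dryrun=True runs every phase except TRT conversion and prints the graph
# schematic and per-engine stats at INFO level (DEBUG when dryrun=False)
compiled = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    dryrun=True,
    min_block_size=1,  # recommended by the commit for the most thorough analysis
)
```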

File tree: 4 files changed (+294 −9 lines)
py/torch_tensorrt/dynamo/_DryRunTracker.py

Lines changed: 205 additions & 0 deletions

@@ -0,0 +1,205 @@
```python
import logging
import math
from dataclasses import dataclass, field
from typing import List, Tuple

import torch

logger = logging.getLogger(__name__)


@dataclass
class PerSubgraphData:
    """Class to track data on a per-subgraph level

    Args:
        subgraph_name (str): Name of the subgraph in the GraphModule
        subgraph_op_count (int): Number of operations in the subgraph
        subgraph_input_shapes (List[Tuple[int, ...]]): Shapes of input Tensors of the subgraph
        subgraph_input_dtypes (List[torch.dtype]): Input data types of the subgraph
        subgraph_output_shapes (List[Tuple[int, ...]]): Shapes of output Tensors of the subgraph
        subgraph_output_dtypes (List[torch.dtype]): Output data types of the subgraph
    """

    subgraph_name: str = ""
    subgraph_op_count: int = 0
    subgraph_input_shapes: List[Tuple[int, ...]] = field(default_factory=list)
    subgraph_input_dtypes: List[torch.dtype] = field(default_factory=list)
    subgraph_output_shapes: List[Tuple[int, ...]] = field(default_factory=list)
    subgraph_output_dtypes: List[torch.dtype] = field(default_factory=list)


@dataclass
class DryRunTracker:
    """Class to track data on a graph-wide level

    Args:
        total_ops_in_graph (int): Total number of operators in graph
        supported_ops_in_graph (int): Number of supported operators in graph
        graph_input_shapes (List[Tuple[int, ...]]): Shapes of input Tensors of the graph
        graph_input_dtypes (List[torch.dtype]): Input data types of the graph
        graph_output_shapes (List[Tuple[int, ...]]): Shapes of output Tensors of the graph
        graph_output_dtypes (List[torch.dtype]): Output data types of the graph
        per_subgraph_data (List[PerSubgraphData]): Per-subgraph data, see above class
        tensorrt_graph_count (int): Number of TensorRT engines to be generated
        truncated_long_and_double (bool): Whether truncate_long_and_double was enabled
    """

    total_ops_in_graph: int = 0
    supported_ops_in_graph: int = 0
    graph_input_shapes: List[Tuple[int, ...]] = field(default_factory=list)
    graph_input_dtypes: List[torch.dtype] = field(default_factory=list)
    graph_output_shapes: List[Tuple[int, ...]] = field(default_factory=list)
    graph_output_dtypes: List[torch.dtype] = field(default_factory=list)
    per_subgraph_data: List[PerSubgraphData] = field(default_factory=list)
    tensorrt_graph_count: int = 0
    truncated_long_and_double: bool = False


def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> None:
    """Displays statistics about the dryrun either to debug logs or info logs"""
    # If user specified "dryrun=True", print to info logs, else debug
    if dryrun_enabled:
        dryrun_logger = logger.info
    else:
        dryrun_logger = logger.debug

    formatted_stats = "\n"

    # Print overall stats about the graph, operator counts, etc.
    formatted_stats += "+" * 50 + " Dry-Run Results for Graph " + "+" * 50 + "\n"
    formatted_stats += (
        f"The graph consists of {dryrun_tracker.total_ops_in_graph} Total Operators, "
        f"of which {dryrun_tracker.supported_ops_in_graph} operators are supported, "
        f"{round(dryrun_tracker.supported_ops_in_graph*100/dryrun_tracker.total_ops_in_graph, 2)}% coverage\n"
    )
    formatted_stats += f"Long and double inputs were {'' if dryrun_tracker.truncated_long_and_double else 'not'} truncated (truncate_long_and_double={dryrun_tracker.truncated_long_and_double})\n"
    formatted_stats += (
        f"{dryrun_tracker.tensorrt_graph_count} TRT Engine(s) were generated\n"
    )

    assert len(dryrun_tracker.per_subgraph_data) == dryrun_tracker.tensorrt_graph_count

    # Print schematic of the graph structure, as in:
    #
    #   Inputs: [Tensor: (1, 3, 224, 224)@float32]
    #    ...
    #    TRT Engine #1: _run_on_acc_0
    #     Engine Inputs: [Tensor: (1, 3, 224, 224)@float32]
    #     Number of Operators in Engine: 1
    #     Engine Outputs: [Tensor: (1, 64, 112, 112)@float32]
    #    ...
    #   Outputs: [Tensor: (1, 1000)@float32]
    #
    formatted_stats += " " * 2 + "Graph Structure:\n\n"
    formatted_stats += (
        " " * 3
        + f"Inputs: [{input_formatter(dryrun_tracker.graph_input_shapes, dryrun_tracker.graph_input_dtypes)}]\n"
    )

    for i, trt_subgraph_data in enumerate(dryrun_tracker.per_subgraph_data):
        assert len(trt_subgraph_data.subgraph_input_dtypes) == len(
            trt_subgraph_data.subgraph_input_shapes
        )
        assert len(trt_subgraph_data.subgraph_output_dtypes) == len(
            trt_subgraph_data.subgraph_output_shapes
        )
        formatted_stats += " " * 4 + "...\n"
        formatted_stats += (
            " " * 4 + f"TRT Engine #{i+1}: {trt_subgraph_data.subgraph_name}\n"
        )
        formatted_stats += (
            " " * 5
            + f"Engine Inputs: [{input_formatter(trt_subgraph_data.subgraph_input_shapes, trt_subgraph_data.subgraph_input_dtypes)}]\n"
        )
        formatted_stats += (
            " " * 5
            + f"Number of Operators in Engine: {trt_subgraph_data.subgraph_op_count}\n"
        )
        formatted_stats += (
            " " * 5
            + f"Engine Outputs: [{input_formatter(trt_subgraph_data.subgraph_output_shapes, trt_subgraph_data.subgraph_output_dtypes)}]\n"
        )

    formatted_stats += " " * 4 + "...\n"
    formatted_stats += (
        " " * 3
        + f"Outputs: [{input_formatter(dryrun_tracker.graph_output_shapes, dryrun_tracker.graph_output_dtypes)}]\n"
    )

    # Print aggregate statistics about the graph structure, including recommended "min_block_size" options
    if dryrun_tracker.tensorrt_graph_count > 0:
        min_ops_in_an_engine = min(
            trt_subgraph.subgraph_op_count
            for trt_subgraph in dryrun_tracker.per_subgraph_data
        )
        avg_ops_per_engine = (
            sum(
                trt_subgraph.subgraph_op_count
                for trt_subgraph in dryrun_tracker.per_subgraph_data
            )
            / dryrun_tracker.tensorrt_graph_count
        )
        avg_ops_per_engine = round(avg_ops_per_engine, 2)
        most_ops_in_an_engine = max(
            trt_subgraph.subgraph_op_count
            for trt_subgraph in dryrun_tracker.per_subgraph_data
        )

        formatted_stats += "\n" + " " * 2 + "-" * 25 + " Aggregate Stats " + "-" * 25
        formatted_stats += (
            "\n\n"
            + " " * 3
            + "Average Number of Operators per TRT Engine: "
            + f"{avg_ops_per_engine}"
        )

        formatted_stats += (
            "\n"
            + " " * 3
            + "Most Operators in a TRT Engine: "
            + f"{most_ops_in_an_engine}"
        )

        formatted_stats += "\n\n" + " " * 2 + "*" * 10 + " Recommendations " + "*" * 10
        formatted_stats += (
            "\n\n"
            + " " * 3
            + "- For minimal graph segmentation, select min_block_size="
            + f"{most_ops_in_an_engine} which would generate "
            + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= most_ops_in_an_engine])} TRT engines"
        )
        if math.ceil(avg_ops_per_engine) != most_ops_in_an_engine:
            formatted_stats += (
                "\n"
                + " " * 3
                + "- For moderate graph segmentation, select min_block_size="
                + f"{math.ceil(avg_ops_per_engine)} which would generate "
                + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= math.ceil(avg_ops_per_engine)])} TRT engines"
            )

        formatted_stats += (
            "\n"
            + " " * 3
            + "- The current level of graph segmentation is equivalent to selecting min_block_size="
            + f"{min_ops_in_an_engine} which generates "
            + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= min_ops_in_an_engine])} TRT engines"
        )
    else:
        formatted_stats += (
            "\n"
            + " " * 2
            + "Aggregate stats not available since no TRT Engines were generated."
        )

    dryrun_logger(formatted_stats)


def input_formatter(shapes: List[Tuple[int, ...]], dtypes: List[torch.dtype]) -> str:
    """Format shapes and dtypes of input Tensors into a readable string"""
    formatted_str = ", "

    for shape, dtype in zip(shapes, dtypes):
        formatted_str += f"Tensor: {shape}@{str(dtype)[6:]}, "

    return formatted_str[2:-2]
```
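
A quick sanity check of `input_formatter` above (the values are hypothetical): `str(dtype)[6:]` drops the leading `torch.` prefix, and the final `[2:-2]` slice trims the leading and trailing `", "` separators:

```python
import torch

shapes = [(1, 3, 224, 224), (1, 10)]
dtypes = [torch.float32, torch.int64]

# Prints: Tensor: (1, 3, 224, 224)@float32, Tensor: (1, 10)@int64
print(input_formatter(shapes, dtypes))
```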

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 84 additions & 9 deletions
```diff
@@ -5,7 +5,6 @@
 from typing import Any, List, Optional, Sequence, Set, Tuple, Union
 
 import torch
-import torch_tensorrt
 from torch.export import ExportedProgram
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import (  # TODO: Should probably be the TRT EngineCapability Enum
@@ -20,6 +19,7 @@
     DLA_GLOBAL_DRAM_SIZE,
     DLA_LOCAL_DRAM_SIZE,
     DLA_SRAM_SIZE,
+    DRYRUN,
     ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     ENGINE_CAPABILITY,
     MAX_AUX_STREAMS,
@@ -37,6 +37,11 @@
     VERSION_COMPATIBLE,
     WORKSPACE_SIZE,
 )
+from torch_tensorrt.dynamo._DryRunTracker import (
+    DryRunTracker,
+    PerSubgraphData,
+    dryrun_stats_display,
+)
 from torch_tensorrt.dynamo.conversion import (
     CompilationSettings,
     convert_module,
@@ -51,6 +56,8 @@
     to_torch_tensorrt_device,
 )
 
+import torch_tensorrt
+
 logger = logging.getLogger(__name__)
 
 
@@ -84,6 +91,7 @@ def compile(
     use_python_runtime: bool = USE_PYTHON_RUNTIME,
     use_fast_partitioner: bool = USE_FAST_PARTITIONER,
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
+    dryrun: bool = DRYRUN,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile a TorchScript module for NVIDIA GPUs using TensorRT
@@ -140,6 +148,7 @@ def compile(
         use_python_runtime: (bool): Return a graph using a pure Python runtime, reduces options for serialization
         use_fast_partitioner: (bool): Use the adjacency based partitioning scheme instead of the global partitioner. Adjacency partitioning is faster but may not be optimal. Use the global partitioner (``False``) if looking for best performance
         enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT.
+        dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -215,6 +224,7 @@ def compile(
         "dla_sram_size": dla_sram_size,
         "dla_local_dram_size": dla_local_dram_size,
         "dla_global_dram_size": dla_global_dram_size,
+        "dryrun": dryrun,
     }
 
     settings = CompilationSettings(**compilation_options)
@@ -238,15 +248,32 @@ def compile_module(
     Returns:
         Compiled FX GraphModule
     """
+    dryrun_tracker = DryRunTracker()
 
     # Check the number of supported operations in the graph
     num_supported_ops, total_ops = partitioning.get_graph_converter_support(
         gm, settings.debug, settings.torch_executed_ops
     )
 
+    dryrun_tracker.total_ops_in_graph = total_ops
+    dryrun_tracker.supported_ops_in_graph = num_supported_ops
+    dryrun_tracker.graph_input_shapes = [
+        tuple(input_.shape) for input_ in sample_inputs
+    ]
+    dryrun_tracker.graph_input_dtypes = [input_.torch_dtype for input_ in sample_inputs]
+    dryrun_tracker.truncated_long_and_double = settings.truncate_long_and_double
+
+    if settings.dryrun and settings.min_block_size > 1:
+        logger.info(
+            "It is recommended to run `dryrun` mode with `min_block_size=1`, "
+            "for the most thorough analysis"
+        )
+
     # If the number of supported operations is 0 or less than the block size, skip the subgraph
     # TODO: Add condition to second expression below when require_full_compilation is added
-    if num_supported_ops == 0 or (num_supported_ops < settings.min_block_size):
+    if num_supported_ops == 0 or (
+        num_supported_ops < settings.min_block_size and not settings.dryrun
+    ):
         logger.warning(
             f"{num_supported_ops} supported operations detected in subgraph containing {total_ops} computational nodes. "
             f"Skipping this subgraph, since min_block_size was detected to be {settings.min_block_size}"
@@ -297,6 +324,16 @@ def compile_module(
         if settings.use_fast_partitioner and "_run_on_acc" not in name:
             continue
 
+        subgraph_data = PerSubgraphData()
+        subgraph_data.subgraph_name = name
+        subgraph_data.subgraph_op_count = len(
+            [
+                node
+                for node in submodule.graph.nodes
+                if node.op in ("call_function", "call_method", "call_module")
+            ]
+        )
+
         # Get the submodule inputs for min, opt, max shapes of the graph inputs
         submodule_inputs = partitioning.get_submod_inputs(
             partitioned_module,
@@ -323,15 +360,51 @@ def compile_module(
             name,
         )
 
-        # Create TRT engines from submodule
-        trt_module = convert_module(
-            submodule,
-            submodule_inputs,
-            settings=settings,
-            name=name,
+        subgraph_data.subgraph_input_dtypes = [
+            submodule_input.torch_dtype for submodule_input in submodule_inputs
+        ]
+        subgraph_data.subgraph_input_shapes = [
+            tuple(submodule_input.shape) for submodule_input in submodule_inputs
+        ]
+
+        submodule_outputs = submodule(
+            *get_torch_inputs(submodule_inputs, to_torch_device(settings.device))
         )
+        if not isinstance(submodule_outputs, (list, tuple)):
+            submodule_outputs = [submodule_outputs]
 
-        trt_modules[name] = trt_module
+        subgraph_data.subgraph_output_dtypes = [
+            submodule_output.dtype for submodule_output in submodule_outputs
+        ]
+        subgraph_data.subgraph_output_shapes = [
+            tuple(submodule_output.shape) for submodule_output in submodule_outputs
+        ]
+
+        dryrun_tracker.tensorrt_graph_count += 1
+        dryrun_tracker.per_subgraph_data.append(subgraph_data)
+
+        # Create TRT engines from submodule
+        if not settings.dryrun:
+            trt_module = convert_module(
+                submodule,
+                submodule_inputs,
+                settings=settings,
+                name=name,
+            )
+
+            trt_modules[name] = trt_module
+
+    sample_outputs = gm(
+        *get_torch_inputs(sample_inputs, to_torch_device(settings.device))
+    )
+
+    if not isinstance(sample_outputs, (list, tuple)):
+        sample_outputs = [sample_outputs]
+
+    dryrun_tracker.graph_output_shapes = [
+        tuple(output_.shape) for output_ in sample_outputs
+    ]
+    dryrun_tracker.graph_output_dtypes = [output_.dtype for output_ in sample_outputs]
 
     # Replace all FX Modules with TRT Modules
     for name, trt_module in trt_modules.items():
@@ -341,4 +414,6 @@ def compile_module(
     if fast_partitioner_failed:
         settings.use_fast_partitioner = True
 
+    dryrun_stats_display(dryrun_tracker, settings.dryrun)
+
     return partitioned_module
```
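
A plausible follow-up once the dry run has printed its recommendations (sketch only, continuing the usage example above; the block size shown is a placeholder for whatever the report suggests):

```python
# Suppose the dry-run report recommended min_block_size=7 for moderate
# segmentation (hypothetical value); re-compile with conversion enabled
compiled = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    dryrun=False,
    min_block_size=7,
)
```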

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -24,6 +24,7 @@
 ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
 REFIT = False
 REQUIRE_FULL_COMPILATION = False
+DRYRUN = False
 
 
 def default_device() -> Device:
```
