Skip to content

Commit c847dc5

Browse files
committed
fix: Address PR comments and update logging scheme
- Fix test case failures
1 parent f29b4bd commit c847dc5

File tree

9 files changed

+181
-115
lines changed

9 files changed

+181
-115
lines changed
Lines changed: 81 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import logging
22
import math
33
from dataclasses import dataclass, field
4-
from typing import List, Tuple
4+
from typing import Any, Dict, List
55

6-
import torch
6+
from torch_tensorrt.dynamo._settings import CompilationSettings
77

88
logger = logging.getLogger(__name__)
99

@@ -15,18 +15,18 @@ class PerSubgraphData:
1515
Args:
1616
subgraph_name (str): Name of the subgraph in the GraphModule
1717
subgraph_op_count (int): Number of operations in the subgraph
18-
subgraph_input_shapes (List[Tuple[int, ...]]): Shapes of input Tensors of the subgraph
19-
subgraph_input_dtypes (List[torch.device]): Input data types of the subgraph
20-
subgraph_output_shapes (List[Tuple[int, ...]]): Shapes of output Tensors of the subgraph
21-
subgraph_output_dtypes (List[torch.device]): Output data types of the subgraph
18+
subgraph_input_shapes (Any): Shapes of input Tensors of the subgraph
19+
subgraph_input_dtypes (Any): Input data types of the subgraph
20+
subgraph_output_shapes (Any): Shapes of output Tensors of the subgraph
21+
subgraph_output_dtypes (Any): Output data types of the subgraph
2222
"""
2323

2424
subgraph_name: str = ""
2525
subgraph_op_count: int = 0
26-
subgraph_input_shapes: List[Tuple[int, ...]] = field(default_factory=list)
27-
subgraph_input_dtypes: List[torch.device] = field(default_factory=list)
28-
subgraph_output_shapes: List[Tuple[int, ...]] = field(default_factory=list)
29-
subgraph_output_dtypes: List[torch.device] = field(default_factory=list)
26+
subgraph_input_shapes: Any = field(default_factory=list)
27+
subgraph_input_dtypes: Any = field(default_factory=list)
28+
subgraph_output_shapes: Any = field(default_factory=list)
29+
subgraph_output_dtypes: Any = field(default_factory=list)
3030

3131

3232
@dataclass
@@ -36,95 +36,86 @@ class DryRunTracker:
3636
Args:
3737
total_ops_in_graph (int): Total number of operators in graph
3838
supported_ops_in_graph (int): Number of supported operators in graph
39-
graph_input_shapes (List[Tuple[int, ...]]): Shapes of input Tensors of the graph
40-
graph_input_dtypes (List[torch.device]): Input data types of the graph
41-
graph_output_shapes (List[Tuple[int, ...]]): Shapes of output Tensors of the graph
42-
graph_output_dtypes (List[torch.device]): Output data types of the graph
39+
graph_input_shapes (Any): Shapes of input Tensors of the graph
40+
graph_input_dtypes (Any): Input data types of the graph
41+
graph_output_shapes (Any): Shapes of output Tensors of the graph
42+
graph_output_dtypes (Any): Output data types of the graph
4343
per_subgraph_data (List[PerSubgraphData]): Per-subgraph data, see above class
4444
tensorrt_graph_count (int): Number of TensorRT engines to be generated
45-
truncated_long_and_double (bool): Whether truncate_long_and_double was enabled
45+
compilation_settings (CompilationSettings): User Compilation Settings
46+
unsupported_ops (Dict[str, int]): Set of operators not supported in TRT
4647
"""
4748

4849
total_ops_in_graph: int = 0
4950
supported_ops_in_graph: int = 0
50-
graph_input_shapes: List[Tuple[int, ...]] = field(default_factory=list)
51-
graph_input_dtypes: List[torch.device] = field(default_factory=list)
52-
graph_output_shapes: List[Tuple[int, ...]] = field(default_factory=list)
53-
graph_output_dtypes: List[torch.device] = field(default_factory=list)
51+
graph_input_shapes: Any = field(default_factory=list)
52+
graph_input_dtypes: Any = field(default_factory=list)
53+
graph_output_shapes: Any = field(default_factory=list)
54+
graph_output_dtypes: Any = field(default_factory=list)
5455
per_subgraph_data: List[PerSubgraphData] = field(default_factory=list)
5556
tensorrt_graph_count: int = 0
56-
truncated_long_and_double: bool = False
57+
compilation_settings: CompilationSettings = field(
58+
default_factory=CompilationSettings
59+
)
60+
unsupported_ops: Dict[str, int] = field(default_factory=dict)
5761

5862

5963
def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> None:
60-
"""Displays statistics about the dryrun either to debug logs or info logs"""
61-
# If user specified "dryrun=True", print to info logs, else debug
62-
if dryrun_enabled:
63-
dryrun_logger = logger.info
64-
else:
65-
dryrun_logger = logger.debug
66-
64+
"""Displays statistics about the dryrun either to debug logs or stdout"""
6765
formatted_stats = "\n"
6866

6967
# Print overall stats about the graph, operator counts, etc.
70-
formatted_stats += "+" * 50 + " Dry-Run Results for Graph " + "+" * 50 + "\n"
68+
formatted_stats += "+" * 50 + " Dry-Run Results for Graph " + "+" * 50 + "\n\n"
7169
formatted_stats += (
7270
f"The graph consists of {dryrun_tracker.total_ops_in_graph} Total Operators, "
7371
f"of which {dryrun_tracker.supported_ops_in_graph} operators are supported, "
74-
f"{round(dryrun_tracker.supported_ops_in_graph*100/dryrun_tracker.total_ops_in_graph, 2)}% coverage\n"
75-
)
76-
formatted_stats += f"Long and double inputs were {'' if dryrun_tracker.truncated_long_and_double else 'not'} truncated (truncate_long_and_double={dryrun_tracker.truncated_long_and_double})\n"
77-
formatted_stats += (
78-
f"{dryrun_tracker.tensorrt_graph_count} TRT Engine(s) were generated\n"
72+
f"{round(dryrun_tracker.supported_ops_in_graph*100/dryrun_tracker.total_ops_in_graph, 2)}% coverage\n\n"
7973
)
74+
formatted_stats += f"The following ops are currently unsupported and set to run in Torch: {dryrun_tracker.unsupported_ops}\n\n"
75+
formatted_stats += f"Compiled with: {dryrun_tracker.compilation_settings}\n\n"
8076

8177
assert len(dryrun_tracker.per_subgraph_data) == dryrun_tracker.tensorrt_graph_count
8278

8379
# Print schematic of the graph structure, as in:
8480
#
85-
# Inputs: [Tensor: (1, 3, 224, 224)@float32]
81+
# Inputs: List[Tensor: (1, 3, 224, 224)@float32]
8682
# ...
87-
# TRT Engine #1: _run_on_acc_0
88-
# Engine Inputs: [Tensor: (1, 3, 224, 224)@float32]
89-
# Number of Operators in Engine: 1
90-
# Engine Outputs: [Tensor: (1, 64, 112, 112)@float32]
83+
# TRT Engine #1 - Submodule name: _run_on_acc_0
84+
# Engine Inputs: List[Tensor: (1, 3, 224, 224)@float32]
85+
# Number of Operators in Engine: 1
86+
# Engine Outputs: Tensor: (1, 64, 112, 112)@float32
9187
# ...
92-
# Outputs: [Tensor: (1, 1000)@float32]
88+
# Outputs: List[Tensor: (1, 1000)@float32]
9389
#
9490
formatted_stats += " " * 2 + "Graph Structure:\n\n"
9591
formatted_stats += (
9692
" " * 3
97-
+ f"Inputs: [{input_formatter(dryrun_tracker.graph_input_shapes, dryrun_tracker.graph_input_dtypes)}]\n"
93+
+ f"Inputs: {input_formatter(dryrun_tracker.graph_input_shapes, dryrun_tracker.graph_input_dtypes)}\n"
9894
)
9995

10096
for i, trt_subgraph_data in enumerate(dryrun_tracker.per_subgraph_data):
101-
assert len(trt_subgraph_data.subgraph_input_dtypes) == len(
102-
trt_subgraph_data.subgraph_input_shapes
103-
)
104-
assert len(trt_subgraph_data.subgraph_output_dtypes) == len(
105-
trt_subgraph_data.subgraph_output_shapes
106-
)
10797
formatted_stats += " " * 4 + "...\n"
10898
formatted_stats += (
109-
" " * 4 + f"TRT Engine #{i+1}: {trt_subgraph_data.subgraph_name}\n"
99+
" " * 4
100+
+ f"TRT Engine #{i+1} - Submodule name: {trt_subgraph_data.subgraph_name}\n"
110101
)
111102
formatted_stats += (
112103
" " * 5
113-
+ f"Engine Inputs: [{input_formatter(trt_subgraph_data.subgraph_input_shapes, trt_subgraph_data.subgraph_input_dtypes)}]\n"
104+
+ f"Engine Inputs: {input_formatter(trt_subgraph_data.subgraph_input_shapes, trt_subgraph_data.subgraph_input_dtypes)}\n"
114105
)
115106
formatted_stats += (
116107
" " * 5
117108
+ f"Number of Operators in Engine: {trt_subgraph_data.subgraph_op_count}\n"
118109
)
119110
formatted_stats += (
120111
" " * 5
121-
+ f"Engine Outputs: [{input_formatter(trt_subgraph_data.subgraph_output_shapes, trt_subgraph_data.subgraph_output_dtypes)}]\n"
112+
+ f"Engine Outputs: {input_formatter(trt_subgraph_data.subgraph_output_shapes, trt_subgraph_data.subgraph_output_dtypes)}\n"
122113
)
123114

124115
formatted_stats += " " * 4 + "...\n"
125116
formatted_stats += (
126117
" " * 3
127-
+ f"Outputs: [{input_formatter(dryrun_tracker.graph_output_shapes, dryrun_tracker.graph_output_dtypes)}]\n"
118+
+ f"Outputs: {input_formatter(dryrun_tracker.graph_output_shapes, dryrun_tracker.graph_output_dtypes)}\n"
128119
)
129120

130121
# Print aggregate statistics about the graph structure, including recommended "min_block_size" options
@@ -167,23 +158,23 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) ->
167158
+ " " * 3
168159
+ "- For minimal graph segmentation, select min_block_size="
169160
+ f"{most_ops_in_an_engine} which would generate "
170-
+ f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= most_ops_in_an_engine])} TRT engines"
161+
+ f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= most_ops_in_an_engine])} TRT engine(s)"
171162
)
172163
if math.ceil(avg_ops_per_engine) != most_ops_in_an_engine:
173164
formatted_stats += (
174165
"\n"
175166
+ " " * 3
176167
+ "- For moderate graph segmentation, select min_block_size="
177168
+ f"{math.ceil(avg_ops_per_engine)} which would generate "
178-
+ f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= math.ceil(avg_ops_per_engine)])} TRT engines"
169+
+ f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= math.ceil(avg_ops_per_engine)])} TRT engine(s)"
179170
)
180171

181172
formatted_stats += (
182173
"\n"
183174
+ " " * 3
184175
+ "- The current level of graph segmentation is equivalent to selecting min_block_size="
185176
+ f"{min_ops_in_an_engine} which generates "
186-
+ f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= min_ops_in_an_engine])} TRT engines"
177+
+ f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= min_ops_in_an_engine])} TRT engine(s)"
187178
)
188179
else:
189180
formatted_stats += (
@@ -192,14 +183,45 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) ->
192183
+ "Aggregate stats not available since no TRT Engines were generated."
193184
)
194185

195-
dryrun_logger(formatted_stats)
186+
# If user specified "dryrun=True", print to stdout, else debug
187+
if dryrun_enabled:
188+
print(formatted_stats)
189+
else:
190+
logger.debug(formatted_stats)
196191

197192

198-
def input_formatter(shapes: List[Tuple[int, ...]], dtypes: List[torch.dtype]) -> str:
193+
def input_formatter(shapes: Any, dtypes: Any) -> str:
199194
"""Format shapes and dtypes of input Tensors into a readable string"""
200-
formatted_str = ", "
201195

202-
for shape, dtype in zip(shapes, dtypes):
203-
formatted_str += f"Tensor: {shape}@{str(dtype)[6:]}, "
196+
def input_formatter_helper(shapes: Any, dtypes: Any) -> str:
197+
"""Helper for input formatter"""
198+
# Base case - single shape, single dtype
199+
if isinstance(shapes, tuple) and all(isinstance(elt, int) for elt in shapes):
200+
return f"Tensor: {shapes}@{str(dtypes)[6:]}, "
201+
202+
# Shapes is a sequence
203+
elif isinstance(shapes, (list, tuple)):
204+
formatted_str = "List[" if isinstance(shapes, list) else "Tuple("
205+
for shape, dtype in zip(shapes, dtypes):
206+
formatted_str += input_formatter_helper(shape, dtype)
207+
formatted_str = formatted_str[:-2] + (
208+
"], " if isinstance(shapes, list) else "), "
209+
)
210+
return formatted_str
211+
212+
# Shapes is a dictionary
213+
elif isinstance(shapes, dict):
214+
formatted_str = "Dict{"
215+
216+
for key, shape in shapes.items():
217+
formatted_str += input_formatter_helper(shape, dtypes[key])
218+
219+
formatted_str = formatted_str[:-2] + "}, "
220+
return formatted_str
221+
222+
else:
223+
raise ValueError(
224+
f"Invalid input type {type(shapes)} encountered in parse_complex_tensor_structs parsing."
225+
)
204226

205-
return formatted_str[2:-2]
227+
return input_formatter_helper(shapes, dtypes)[:-2]

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
4343
from torch_tensorrt.dynamo.utils import (
4444
get_torch_inputs,
45+
parse_complex_tensor_structs,
4546
prepare_inputs,
4647
set_log_level,
4748
to_torch_device,
@@ -241,11 +242,13 @@ def compile_module(
241242

242243
dryrun_tracker.total_ops_in_graph = total_ops
243244
dryrun_tracker.supported_ops_in_graph = num_supported_ops
244-
dryrun_tracker.graph_input_shapes = [
245-
tuple(input_.shape) for input_ in sample_inputs
246-
]
247-
dryrun_tracker.graph_input_dtypes = [input_.torch_dtype for input_ in sample_inputs]
248-
dryrun_tracker.truncated_long_and_double = settings.truncate_long_and_double
245+
dryrun_tracker.graph_input_shapes = parse_complex_tensor_structs(
246+
sample_inputs, "shape", tuple
247+
)
248+
dryrun_tracker.graph_input_dtypes = parse_complex_tensor_structs(
249+
sample_inputs, "torch_dtype"
250+
)
251+
dryrun_tracker.compilation_settings = settings
249252

250253
if settings.dryrun and settings.min_block_size > 1:
251254
logger.info(
@@ -274,7 +277,7 @@ def compile_module(
274277
# If specified, try using the fast partitioner and fall back to the global one on failure
275278
if settings.use_fast_partitioner:
276279
try:
277-
partitioned_module = partitioning.fast_partition(
280+
partitioned_module, supported_ops = partitioning.fast_partition(
278281
gm,
279282
verbose=settings.debug,
280283
min_block_size=settings.min_block_size,
@@ -291,13 +294,15 @@ def compile_module(
291294
settings.use_fast_partitioner = False
292295

293296
if not settings.use_fast_partitioner:
294-
partitioned_module = partitioning.global_partition(
297+
partitioned_module, supported_ops = partitioning.global_partition(
295298
gm,
296299
verbose=settings.debug,
297300
min_block_size=settings.min_block_size,
298301
torch_executed_ops=settings.torch_executed_ops,
299302
)
300303

304+
dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators
305+
301306
# Store TRT replicas of Torch subgraphs
302307
trt_modules = {}
303308
# Iterate over all components that can be accelerated
@@ -344,25 +349,23 @@ def compile_module(
344349
name,
345350
)
346351

347-
subgraph_data.subgraph_input_dtypes = [
348-
submodule_input.torch_dtype for submodule_input in submodule_inputs
349-
]
350-
subgraph_data.subgraph_input_shapes = [
351-
tuple(submodule_input.shape) for submodule_input in submodule_inputs
352-
]
352+
subgraph_data.subgraph_input_shapes = parse_complex_tensor_structs(
353+
submodule_inputs, "shape", tuple
354+
)
355+
subgraph_data.subgraph_input_dtypes = parse_complex_tensor_structs(
356+
submodule_inputs, "torch_dtype"
357+
)
353358

354359
submodule_outputs = submodule(
355360
*get_torch_inputs(submodule_inputs, to_torch_device(settings.device))
356361
)
357-
if not isinstance(submodule_outputs, (list, tuple)):
358-
submodule_outputs = [submodule_outputs]
359362

360-
subgraph_data.subgraph_output_dtypes = [
361-
submodule_output.dtype for submodule_output in submodule_outputs
362-
]
363-
subgraph_data.subgraph_output_shapes = [
364-
tuple(submodule_output.shape) for submodule_output in submodule_outputs
365-
]
363+
subgraph_data.subgraph_output_shapes = parse_complex_tensor_structs(
364+
submodule_outputs, "shape", tuple
365+
)
366+
subgraph_data.subgraph_output_dtypes = parse_complex_tensor_structs(
367+
submodule_outputs, "dtype"
368+
)
366369

367370
dryrun_tracker.tensorrt_graph_count += 1
368371
dryrun_tracker.per_subgraph_data.append(subgraph_data)
@@ -385,10 +388,12 @@ def compile_module(
385388
if not isinstance(sample_outputs, (list, tuple)):
386389
sample_outputs = [sample_outputs]
387390

388-
dryrun_tracker.graph_output_shapes = [
389-
tuple(output_.shape) for output_ in sample_outputs
390-
]
391-
dryrun_tracker.graph_output_dtypes = [output_.dtype for output_ in sample_outputs]
391+
dryrun_tracker.graph_output_shapes = parse_complex_tensor_structs(
392+
sample_outputs, "shape", tuple
393+
)
394+
dryrun_tracker.graph_output_dtypes = parse_complex_tensor_structs(
395+
sample_outputs, "dtype"
396+
)
392397

393398
# Replace all FX Modules with TRT Modules
394399
for name, trt_module in trt_modules.items():

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def partition(
248248
min_block_size: int = MIN_BLOCK_SIZE,
249249
torch_executed_ops: Collection[Target] = set(),
250250
require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
251-
) -> torch.fx.GraphModule:
251+
) -> Tuple[torch.fx.GraphModule, OpSupportTester]:
252252
"""Partition an FX GraphModule with aten ops into TRT engines
253253
Partitioning is based on converter operator support
254254
@@ -259,7 +259,7 @@ def partition(
259259
torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
260260
require_full_compilation: Require that all computational operators be run in TRT
261261
Returns:
262-
torch.fx.GraphModule
262+
torch.fx.GraphModule, OpSupportTester
263263
"""
264264
# Ensure graph is clean prior to partitioning
265265
gm.graph.eliminate_dead_code()
@@ -280,4 +280,4 @@ def partition(
280280
if verbose:
281281
supported_ops.print_support_overview(partitioner.num_trt_accelerated_subgraphs)
282282

283-
return partitioned_graph
283+
return partitioned_graph, supported_ops

0 commit comments

Comments
 (0)