Skip to content

Commit d121545

Browse files
committed
fix: Address PR comments and update logging scheme
- Fix test case failures
1 parent 635e26a commit d121545

File tree

9 files changed

+181
-115
lines changed

9 files changed

+181
-115
lines changed
Lines changed: 81 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import logging
22
import math
33
from dataclasses import dataclass, field
4-
from typing import List, Tuple
4+
from typing import Any, Dict, List
55

6-
import torch
6+
from torch_tensorrt.dynamo._settings import CompilationSettings
77

88
logger = logging.getLogger(__name__)
99

@@ -15,18 +15,18 @@ class PerSubgraphData:
1515
Args:
1616
subgraph_name (str): Name of the subgraph in the GraphModule
1717
subgraph_op_count (int): Number of operations in the subgraph
18-
subgraph_input_shapes (List[Tuple[int, ...]]): Shapes of input Tensors of the subgraph
19-
subgraph_input_dtypes (List[torch.device]): Input data types of the subgraph
20-
subgraph_output_shapes (List[Tuple[int, ...]]): Shapes of output Tensors of the subgraph
21-
subgraph_output_dtypes (List[torch.device]): Output data types of the subgraph
18+
subgraph_input_shapes (Any): Shapes of input Tensors of the subgraph
19+
subgraph_input_dtypes (Any): Input data types of the subgraph
20+
subgraph_output_shapes (Any): Shapes of output Tensors of the subgraph
21+
subgraph_output_dtypes (Any): Output data types of the subgraph
2222
"""
2323

2424
subgraph_name: str = ""
2525
subgraph_op_count: int = 0
26-
subgraph_input_shapes: List[Tuple[int, ...]] = field(default_factory=list)
27-
subgraph_input_dtypes: List[torch.device] = field(default_factory=list)
28-
subgraph_output_shapes: List[Tuple[int, ...]] = field(default_factory=list)
29-
subgraph_output_dtypes: List[torch.device] = field(default_factory=list)
26+
subgraph_input_shapes: Any = field(default_factory=list)
27+
subgraph_input_dtypes: Any = field(default_factory=list)
28+
subgraph_output_shapes: Any = field(default_factory=list)
29+
subgraph_output_dtypes: Any = field(default_factory=list)
3030

3131

3232
@dataclass
@@ -36,95 +36,86 @@ class DryRunTracker:
3636
Args:
3737
total_ops_in_graph (int): Total number of operators in graph
3838
supported_ops_in_graph (int): Number of supported operators in graph
39-
graph_input_shapes (List[Tuple[int, ...]]): Shapes of input Tensors of the graph
40-
graph_input_dtypes (List[torch.device]): Input data types of the graph
41-
graph_output_shapes (List[Tuple[int, ...]]): Shapes of output Tensors of the graph
42-
graph_output_dtypes (List[torch.device]): Output data types of the graph
39+
graph_input_shapes (Any): Shapes of input Tensors of the graph
40+
graph_input_dtypes (Any): Input data types of the graph
41+
graph_output_shapes (Any): Shapes of output Tensors of the graph
42+
graph_output_dtypes (Any): Output data types of the graph
4343
per_subgraph_data (List[PerSubgraphData]): Per-subgraph data, see above class
4444
tensorrt_graph_count (int): Number of TensorRT engines to be generated
45-
truncated_long_and_double (bool): Whether truncate_long_and_double was enabled
45+
compilation_settings (CompilationSettings): User Compilation Settings
46+
unsupported_ops (Dict[str, int]): Set of operators not supported in TRT
4647
"""
4748

4849
total_ops_in_graph: int = 0
4950
supported_ops_in_graph: int = 0
50-
graph_input_shapes: List[Tuple[int, ...]] = field(default_factory=list)
51-
graph_input_dtypes: List[torch.device] = field(default_factory=list)
52-
graph_output_shapes: List[Tuple[int, ...]] = field(default_factory=list)
53-
graph_output_dtypes: List[torch.device] = field(default_factory=list)
51+
graph_input_shapes: Any = field(default_factory=list)
52+
graph_input_dtypes: Any = field(default_factory=list)
53+
graph_output_shapes: Any = field(default_factory=list)
54+
graph_output_dtypes: Any = field(default_factory=list)
5455
per_subgraph_data: List[PerSubgraphData] = field(default_factory=list)
5556
tensorrt_graph_count: int = 0
56-
truncated_long_and_double: bool = False
57+
compilation_settings: CompilationSettings = field(
58+
default_factory=CompilationSettings
59+
)
60+
unsupported_ops: Dict[str, int] = field(default_factory=dict)
5761

5862

5963
def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> None:
60-
"""Displays statistics about the dryrun either to debug logs or info logs"""
61-
# If user specified "dryrun=True", print to info logs, else debug
62-
if dryrun_enabled:
63-
dryrun_logger = logger.info
64-
else:
65-
dryrun_logger = logger.debug
66-
64+
"""Displays statistics about the dryrun either to debug logs or stdout"""
6765
formatted_stats = "\n"
6866

6967
# Print overall stats about the graph, operator counts, etc.
70-
formatted_stats += "+" * 50 + " Dry-Run Results for Graph " + "+" * 50 + "\n"
68+
formatted_stats += "+" * 50 + " Dry-Run Results for Graph " + "+" * 50 + "\n\n"
7169
formatted_stats += (
7270
f"The graph consists of {dryrun_tracker.total_ops_in_graph} Total Operators, "
7371
f"of which {dryrun_tracker.supported_ops_in_graph} operators are supported, "
74-
f"{round(dryrun_tracker.supported_ops_in_graph*100/dryrun_tracker.total_ops_in_graph, 2)}% coverage\n"
75-
)
76-
formatted_stats += f"Long and double inputs were {'' if dryrun_tracker.truncated_long_and_double else 'not'} truncated (truncate_long_and_double={dryrun_tracker.truncated_long_and_double})\n"
77-
formatted_stats += (
78-
f"{dryrun_tracker.tensorrt_graph_count} TRT Engine(s) were generated\n"
72+
f"{round(dryrun_tracker.supported_ops_in_graph*100/dryrun_tracker.total_ops_in_graph, 2)}% coverage\n\n"
7973
)
74+
formatted_stats += f"The following ops are currently unsupported and set to run in Torch: {dryrun_tracker.unsupported_ops}\n\n"
75+
formatted_stats += f"Compiled with: {dryrun_tracker.compilation_settings}\n\n"
8076

8177
assert len(dryrun_tracker.per_subgraph_data) == dryrun_tracker.tensorrt_graph_count
8278

8379
# Print schematic of the graph structure, as in:
8480
#
85-
# Inputs: [Tensor: (1, 3, 224, 224)@float32]
81+
# Inputs: List[Tensor: (1, 3, 224, 224)@float32]
8682
# ...
87-
# TRT Engine #1: _run_on_acc_0
88-
# Engine Inputs: [Tensor: (1, 3, 224, 224)@float32]
89-
# Number of Operators in Engine: 1
90-
# Engine Outputs: [Tensor: (1, 64, 112, 112)@float32]
83+
# TRT Engine #1 - Submodule name: _run_on_acc_0
84+
# Engine Inputs: List[Tensor: (1, 3, 224, 224)@float32]
85+
# Number of Operators in Engine: 1
86+
# Engine Outputs: Tensor: (1, 64, 112, 112)@float32
9187
# ...
92-
# Outputs: [Tensor: (1, 1000)@float32]
88+
# Outputs: List[Tensor: (1, 1000)@float32]
9389
#
9490
formatted_stats += " " * 2 + "Graph Structure:\n\n"
9591
formatted_stats += (
9692
" " * 3
97-
+ f"Inputs: [{input_formatter(dryrun_tracker.graph_input_shapes, dryrun_tracker.graph_input_dtypes)}]\n"
93+
+ f"Inputs: {input_formatter(dryrun_tracker.graph_input_shapes, dryrun_tracker.graph_input_dtypes)}\n"
9894
)
9995

10096
for i, trt_subgraph_data in enumerate(dryrun_tracker.per_subgraph_data):
101-
assert len(trt_subgraph_data.subgraph_input_dtypes) == len(
102-
trt_subgraph_data.subgraph_input_shapes
103-
)
104-
assert len(trt_subgraph_data.subgraph_output_dtypes) == len(
105-
trt_subgraph_data.subgraph_output_shapes
106-
)
10797
formatted_stats += " " * 4 + "...\n"
10898
formatted_stats += (
109-
" " * 4 + f"TRT Engine #{i+1}: {trt_subgraph_data.subgraph_name}\n"
99+
" " * 4
100+
+ f"TRT Engine #{i+1} - Submodule name: {trt_subgraph_data.subgraph_name}\n"
110101
)
111102
formatted_stats += (
112103
" " * 5
113-
+ f"Engine Inputs: [{input_formatter(trt_subgraph_data.subgraph_input_shapes, trt_subgraph_data.subgraph_input_dtypes)}]\n"
104+
+ f"Engine Inputs: {input_formatter(trt_subgraph_data.subgraph_input_shapes, trt_subgraph_data.subgraph_input_dtypes)}\n"
114105
)
115106
formatted_stats += (
116107
" " * 5
117108
+ f"Number of Operators in Engine: {trt_subgraph_data.subgraph_op_count}\n"
118109
)
119110
formatted_stats += (
120111
" " * 5
121-
+ f"Engine Outputs: [{input_formatter(trt_subgraph_data.subgraph_output_shapes, trt_subgraph_data.subgraph_output_dtypes)}]\n"
112+
+ f"Engine Outputs: {input_formatter(trt_subgraph_data.subgraph_output_shapes, trt_subgraph_data.subgraph_output_dtypes)}\n"
122113
)
123114

124115
formatted_stats += " " * 4 + "...\n"
125116
formatted_stats += (
126117
" " * 3
127-
+ f"Outputs: [{input_formatter(dryrun_tracker.graph_output_shapes, dryrun_tracker.graph_output_dtypes)}]\n"
118+
+ f"Outputs: {input_formatter(dryrun_tracker.graph_output_shapes, dryrun_tracker.graph_output_dtypes)}\n"
128119
)
129120

130121
# Print aggregate statistics about the graph structure, including recommended "min_block_size" options
@@ -167,23 +158,23 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) ->
167158
+ " " * 3
168159
+ "- For minimal graph segmentation, select min_block_size="
169160
+ f"{most_ops_in_an_engine} which would generate "
170-
+ f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= most_ops_in_an_engine])} TRT engines"
161+
+ f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= most_ops_in_an_engine])} TRT engine(s)"
171162
)
172163
if math.ceil(avg_ops_per_engine) != most_ops_in_an_engine:
173164
formatted_stats += (
174165
"\n"
175166
+ " " * 3
176167
+ "- For moderate graph segmentation, select min_block_size="
177168
+ f"{math.ceil(avg_ops_per_engine)} which would generate "
178-
+ f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= math.ceil(avg_ops_per_engine)])} TRT engines"
169+
+ f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= math.ceil(avg_ops_per_engine)])} TRT engine(s)"
179170
)
180171

181172
formatted_stats += (
182173
"\n"
183174
+ " " * 3
184175
+ "- The current level of graph segmentation is equivalent to selecting min_block_size="
185176
+ f"{min_ops_in_an_engine} which generates "
186-
+ f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= min_ops_in_an_engine])} TRT engines"
177+
+ f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= min_ops_in_an_engine])} TRT engine(s)"
187178
)
188179
else:
189180
formatted_stats += (
@@ -192,14 +183,45 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) ->
192183
+ "Aggregate stats not available since no TRT Engines were generated."
193184
)
194185

195-
dryrun_logger(formatted_stats)
186+
# If user specified "dryrun=True", print to stdout, else debug
187+
if dryrun_enabled:
188+
print(formatted_stats)
189+
else:
190+
logger.debug(formatted_stats)
196191

197192

198-
def input_formatter(shapes: List[Tuple[int, ...]], dtypes: List[torch.dtype]) -> str:
193+
def input_formatter(shapes: Any, dtypes: Any) -> str:
    """Format shapes and dtypes of input Tensors into a readable string.

    ``shapes`` and ``dtypes`` are parallel structures: wherever ``shapes``
    holds a list, tuple or dict, ``dtypes`` must hold a container that can be
    zipped with it (sequences) or indexed by the same keys (dicts).

    Args:
        shapes: Either a single tensor shape (a tuple of ints) or a
            list/tuple/dict whose elements are themselves valid ``shapes``.
        dtypes: Structure mirroring ``shapes`` and holding the dtype for each
            tensor. Each dtype is rendered through ``str()``; a leading
            ``"torch."`` is dropped so ``torch.float32`` prints as ``float32``.

    Returns:
        A string such as ``"List[Tensor: (1, 3, 224, 224)@float32]"``.
        Lists render as ``List[...]``, tuples of shapes as ``Tuple(...)`` and
        dicts as ``Dict{...}`` (dictionary keys are not printed). Empty
        containers render as ``List[]`` / ``Dict{}``.

    Raises:
        ValueError: If ``shapes`` (at any nesting level) is not a tuple, list
            or dict.
    """

    def format_dtype(dtype: Any) -> str:
        """Render a dtype without its ``torch.`` namespace prefix."""
        dtype_text = str(dtype)
        # Strip only the known prefix instead of blindly slicing 6 characters,
        # so values that do not stringify as "torch.<name>" are left intact.
        if dtype_text.startswith("torch."):
            return dtype_text[len("torch."):]
        return dtype_text

    def input_formatter_helper(shapes: Any, dtypes: Any) -> str:
        """Recursively format one level of the shapes/dtypes structure."""
        # Base case - single shape, single dtype.
        # Note: an empty tuple also lands here and is shown as a 0-dim shape.
        if isinstance(shapes, tuple) and all(isinstance(elt, int) for elt in shapes):
            return f"Tensor: {shapes}@{format_dtype(dtypes)}"

        # Shapes is a sequence of nested shape structures
        if isinstance(shapes, (list, tuple)):
            if isinstance(shapes, list):
                opener, closer = "List[", "]"
            else:
                opener, closer = "Tuple(", ")"
            # Joining (rather than appending ", " and slicing it off again)
            # keeps the opener intact when the sequence is empty.
            members = ", ".join(
                input_formatter_helper(shape, dtype)
                for shape, dtype in zip(shapes, dtypes)
            )
            return f"{opener}{members}{closer}"

        # Shapes is a dictionary; dtypes is looked up with the same keys
        if isinstance(shapes, dict):
            members = ", ".join(
                input_formatter_helper(shape, dtypes[key])
                for key, shape in shapes.items()
            )
            return "Dict{" + members + "}"

        raise ValueError(
            f"Invalid input type {type(shapes)} encountered in input_formatter parsing."
        )

    return input_formatter_helper(shapes, dtypes)

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
5151
from torch_tensorrt.dynamo.utils import (
5252
get_torch_inputs,
53+
parse_complex_tensor_structs,
5354
prepare_inputs,
5455
set_log_level,
5556
to_torch_device,
@@ -257,11 +258,13 @@ def compile_module(
257258

258259
dryrun_tracker.total_ops_in_graph = total_ops
259260
dryrun_tracker.supported_ops_in_graph = num_supported_ops
260-
dryrun_tracker.graph_input_shapes = [
261-
tuple(input_.shape) for input_ in sample_inputs
262-
]
263-
dryrun_tracker.graph_input_dtypes = [input_.torch_dtype for input_ in sample_inputs]
264-
dryrun_tracker.truncated_long_and_double = settings.truncate_long_and_double
261+
dryrun_tracker.graph_input_shapes = parse_complex_tensor_structs(
262+
sample_inputs, "shape", tuple
263+
)
264+
dryrun_tracker.graph_input_dtypes = parse_complex_tensor_structs(
265+
sample_inputs, "torch_dtype"
266+
)
267+
dryrun_tracker.compilation_settings = settings
265268

266269
if settings.dryrun and settings.min_block_size > 1:
267270
logger.info(
@@ -290,7 +293,7 @@ def compile_module(
290293
# If specified, try using the fast partitioner and fall back to the global one on failure
291294
if settings.use_fast_partitioner:
292295
try:
293-
partitioned_module = partitioning.fast_partition(
296+
partitioned_module, supported_ops = partitioning.fast_partition(
294297
gm,
295298
verbose=settings.debug,
296299
min_block_size=settings.min_block_size,
@@ -307,13 +310,15 @@ def compile_module(
307310
settings.use_fast_partitioner = False
308311

309312
if not settings.use_fast_partitioner:
310-
partitioned_module = partitioning.global_partition(
313+
partitioned_module, supported_ops = partitioning.global_partition(
311314
gm,
312315
verbose=settings.debug,
313316
min_block_size=settings.min_block_size,
314317
torch_executed_ops=settings.torch_executed_ops,
315318
)
316319

320+
dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators
321+
317322
# Store TRT replicas of Torch subgraphs
318323
trt_modules = {}
319324
# Iterate over all components that can be accelerated
@@ -360,25 +365,23 @@ def compile_module(
360365
name,
361366
)
362367

363-
subgraph_data.subgraph_input_dtypes = [
364-
submodule_input.torch_dtype for submodule_input in submodule_inputs
365-
]
366-
subgraph_data.subgraph_input_shapes = [
367-
tuple(submodule_input.shape) for submodule_input in submodule_inputs
368-
]
368+
subgraph_data.subgraph_input_shapes = parse_complex_tensor_structs(
369+
submodule_inputs, "shape", tuple
370+
)
371+
subgraph_data.subgraph_input_dtypes = parse_complex_tensor_structs(
372+
submodule_inputs, "torch_dtype"
373+
)
369374

370375
submodule_outputs = submodule(
371376
*get_torch_inputs(submodule_inputs, to_torch_device(settings.device))
372377
)
373-
if not isinstance(submodule_outputs, (list, tuple)):
374-
submodule_outputs = [submodule_outputs]
375378

376-
subgraph_data.subgraph_output_dtypes = [
377-
submodule_output.dtype for submodule_output in submodule_outputs
378-
]
379-
subgraph_data.subgraph_output_shapes = [
380-
tuple(submodule_output.shape) for submodule_output in submodule_outputs
381-
]
379+
subgraph_data.subgraph_output_shapes = parse_complex_tensor_structs(
380+
submodule_outputs, "shape", tuple
381+
)
382+
subgraph_data.subgraph_output_dtypes = parse_complex_tensor_structs(
383+
submodule_outputs, "dtype"
384+
)
382385

383386
dryrun_tracker.tensorrt_graph_count += 1
384387
dryrun_tracker.per_subgraph_data.append(subgraph_data)
@@ -401,10 +404,12 @@ def compile_module(
401404
if not isinstance(sample_outputs, (list, tuple)):
402405
sample_outputs = [sample_outputs]
403406

404-
dryrun_tracker.graph_output_shapes = [
405-
tuple(output_.shape) for output_ in sample_outputs
406-
]
407-
dryrun_tracker.graph_output_dtypes = [output_.dtype for output_ in sample_outputs]
407+
dryrun_tracker.graph_output_shapes = parse_complex_tensor_structs(
408+
sample_outputs, "shape", tuple
409+
)
410+
dryrun_tracker.graph_output_dtypes = parse_complex_tensor_structs(
411+
sample_outputs, "dtype"
412+
)
408413

409414
# Replace all FX Modules with TRT Modules
410415
for name, trt_module in trt_modules.items():

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def partition(
248248
min_block_size: int = MIN_BLOCK_SIZE,
249249
torch_executed_ops: Collection[Target] = set(),
250250
require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
251-
) -> torch.fx.GraphModule:
251+
) -> Tuple[torch.fx.GraphModule, OpSupportTester]:
252252
"""Partition an FX GraphModule with aten ops into TRT engines
253253
Partitioning is based on converter operator support
254254
@@ -259,7 +259,7 @@ def partition(
259259
torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
260260
require_full_compilation: Require that all computational operators be run in TRT
261261
Returns:
262-
torch.fx.GraphModule
262+
torch.fx.GraphModule, OpSupportTester
263263
"""
264264
# Ensure graph is clean prior to partitioning
265265
gm.graph.eliminate_dead_code()
@@ -280,4 +280,4 @@ def partition(
280280
if verbose:
281281
supported_ops.print_support_overview(partitioner.num_trt_accelerated_subgraphs)
282282

283-
return partitioned_graph
283+
return partitioned_graph, supported_ops

0 commit comments

Comments
 (0)