Commit 7ef0b13

cherry-pick: Dynamo upgrades and bugfixes (release/1.4) (#1956)
1 parent 93f4be4 commit 7ef0b13

File tree

12 files changed: +91, -30 lines


py/torch_tensorrt/dynamo/backend/__init__.py

Lines changed: 10 additions & 1 deletion
@@ -16,6 +16,7 @@
     DEBUG,
     MAX_WORKSPACE_SIZE,
     MIN_BLOCK_SIZE,
+    PASS_THROUGH_BUILD_FAILURES,
 )
 
 
@@ -46,11 +47,14 @@ def compile(
     torch_executed_modules=[],
     **kwargs,
 ):
+    if debug:
+        logger.setLevel(logging.DEBUG)
 
     logger.warn(
         "The Dynamo backend is an experimental feature, for which only the "
         + "following arguments are supported: "
-        + "{enabled_precisions, debug, workspace_size, min_block_size, torch_executed_ops}"
+        + "{enabled_precisions, debug, workspace_size, min_block_size, "
+        + "torch_executed_ops, pass_through_build_failures}"
     )
 
     if not isinstance(inputs, collections.abc.Sequence):
@@ -104,6 +108,7 @@ def create_backend(
     workspace_size: int = MAX_WORKSPACE_SIZE,
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Sequence[str] = set(),
+    pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
     **kwargs,
 ):
     """Create torch.compile backend given specified arguments
@@ -116,12 +121,16 @@ def create_backend(
     Returns:
         Backend for torch.compile
     """
+    if debug:
+        logger.setLevel(logging.DEBUG)
+
     settings = CompilationSettings(
         debug=debug,
         precision=precision,
         workspace_size=workspace_size,
         min_block_size=min_block_size,
         torch_executed_ops=torch_executed_ops,
+        pass_through_build_failures=pass_through_build_failures,
     )
 
     return partial(
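
For context, a minimal sketch of exercising the new flag through this compile API. The module path is inferred from the file location above, and the model and input shape are illustrative, not from the commit:

    import torch
    import torch_tensorrt

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval().cuda()
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]

    trt_model = torch_tensorrt.dynamo.backend.compile(
        model,
        inputs,
        enabled_precisions={torch.float},
        debug=True,  # per this commit, also raises the module logger to DEBUG
        pass_through_build_failures=True,  # raise on engine-build failure instead of falling back
    )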

py/torch_tensorrt/dynamo/backend/_defaults.py

Lines changed: 1 addition & 0 deletions
@@ -5,3 +5,4 @@
 DEBUG = False
 MAX_WORKSPACE_SIZE = 20 << 30
 MIN_BLOCK_SIZE = 5
+PASS_THROUGH_BUILD_FAILURES = False

py/torch_tensorrt/dynamo/backend/_settings.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
     DEBUG,
     MAX_WORKSPACE_SIZE,
     MIN_BLOCK_SIZE,
+    PASS_THROUGH_BUILD_FAILURES,
 )
 
 
@@ -17,3 +18,4 @@ class CompilationSettings:
     workspace_size: int = MAX_WORKSPACE_SIZE
     min_block_size: int = MIN_BLOCK_SIZE
     torch_executed_ops: Sequence[str] = field(default_factory=set)
+    pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES
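
The dataclass now carries the full set of options threaded through the backend. A hedged construction sketch follows; defaults come from _defaults.py, and the torch_executed_ops value is illustrative:

    from torch_tensorrt.dynamo.backend._settings import CompilationSettings

    # Only overridden fields need to be passed; the rest fall back to the defaults
    settings = CompilationSettings(
        debug=True,
        min_block_size=5,
        torch_executed_ops={"torch.ops.aten.sub.Tensor"},
        pass_through_build_failures=True,
    )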

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 20 additions & 8 deletions
@@ -1,6 +1,6 @@
+import logging
 from typing import Sequence
 import torch
-import traceback
 from functools import partial
 import torch._dynamo as td
 
@@ -19,6 +19,9 @@
 from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
 
 
+logger = logging.getLogger(__name__)
+
+
 @td.register_backend(name="torch_tensorrt")
 @fake_tensor_unsupported
 def torch_tensorrt_backend(
@@ -52,6 +55,7 @@ def aot_torch_tensorrt_aten_backend(
     )
 
 
+@fake_tensor_unsupported
 def _pretraced_backend(
     gm: torch.fx.GraphModule,
     sample_inputs: Sequence[torch.Tensor],
@@ -74,12 +78,22 @@ def _pretraced_backend(
         )
         return trt_compiled
     except:
-        traceback.print_exc()
-        print(
+        logger.error(
             "FX2TRT conversion failed on the subgraph. See trace above. "
-            + "Returning GraphModule forward instead."
+            + "Returning GraphModule forward instead.",
+            exc_info=True,
         )
-        return gm.forward
+
+        if not settings.pass_through_build_failures:
+            return gm.forward
+        else:
+            raise AssertionError(
+                "Halting compilation on build failure since "
+                + "pass_through_build_failures was specified as True. "
+                + "To return the default Torch implementation and avoid "
+                + "halting compilation on engine build failures, "
+                + "specify pass_through_build_failures=False."
+            )
 
 
 def _compile_module(
@@ -120,9 +134,7 @@ def _compile_module(
             trt_mod = convert_module(
                 submodule,
                 submodule_inputs,
-                debug=settings.debug,
-                workspace_size=settings.workspace_size,
-                precision=settings.precision,
+                settings=settings,
             )
 
             # Replace FX Module with TRT Module
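
The except-block rewrite amounts to a fail-fast toggle around a logged fallback. A self-contained sketch of the same pattern; the function and argument names here are illustrative, not from the commit:

    import logging

    logger = logging.getLogger(__name__)

    def build_or_fallback(build_fn, fallback, pass_through_build_failures=False):
        try:
            return build_fn()
        except Exception:
            # exc_info=True attaches the full traceback to the log record,
            # replacing the old traceback.print_exc() + print() pair
            logger.error("Build failed. See trace above.", exc_info=True)
            if not pass_through_build_failures:
                return fallback
            raise AssertionError("Halting compilation on build failure.")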

py/torch_tensorrt/dynamo/backend/conversion.py

Lines changed: 7 additions & 11 deletions
@@ -2,45 +2,41 @@
 import torch
 from torch_tensorrt.fx.trt_module import TRTModule
 from torch_tensorrt import TRTModuleNext
+from torch_tensorrt.dynamo.backend._settings import CompilationSettings
 from torch_tensorrt.fx.fx2trt import (
     InputTensorSpec,
     TRTInterpreter,
 )
-from torch_tensorrt.fx.utils import LowerPrecision
 
 import tensorrt as trt
 
 
 def convert_module(
     module: torch.fx.GraphModule,
     inputs: Sequence[torch.Tensor],
-    debug: bool = False,
-    workspace_size: int = 20 << 30,
-    precision: LowerPrecision = LowerPrecision.FP32,
+    settings: CompilationSettings = CompilationSettings(),
 ) -> Union[TRTModuleNext, TRTModule]:
     """Convert an FX module to a TRT module
     Args:
         module: FX GraphModule to convert
         inputs: Sequence of Tensors representing inputs to the module
-        debug: Whether to print out verbose debugging information
-        workspace_size: Maximum workspace TRT is allowed to use for the module
-        precision: Model Layer precision
+        settings: Compilation settings
     Returns:
         TRTModule or TRTModuleNext
     """
     interp = TRTInterpreter(
         module,
         InputTensorSpec.from_tensors(inputs),
         explicit_batch_dimension=True,
-        logger_level=(trt.Logger.VERBOSE if debug else trt.Logger.WARNING),
+        logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
     )
 
     r = interp.run(
-        max_workspace_size=workspace_size,
-        lower_precision=precision,
+        max_workspace_size=settings.workspace_size,
+        lower_precision=settings.precision,
         profiling_verbosity=(
             trt.ProfilingVerbosity.VERBOSE
-            if debug
+            if settings.debug
             else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
         ),
     )
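
With the signature change, call sites hand over one settings object instead of loose keyword arguments, as _compile_module above now does. A hedged sketch of the new call; the traced module here is a toy stand-in for a partitioner-produced subgraph and may not convert as-is:

    import torch
    from torch_tensorrt.dynamo.backend._settings import CompilationSettings
    from torch_tensorrt.dynamo.backend.conversion import convert_module

    class Add(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    submodule = torch.fx.symbolic_trace(Add())
    submodule_inputs = [torch.randn(4).cuda(), torch.randn(4).cuda()]

    # One settings object replaces the old debug/workspace_size/precision kwargs
    trt_mod = convert_module(
        submodule,
        submodule_inputs,
        settings=CompilationSettings(debug=True, workspace_size=20 << 30),
    )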

py/torch_tensorrt/dynamo/backend/lowering/_partition.py

Lines changed: 8 additions & 5 deletions
@@ -136,15 +136,18 @@ def print_support_overview(self, num_trt_blocks: Optional[int] = None):
             f"\nNumber of TensorRT-Accelerated Engines Generated: {num_trt_blocks}"
         )
 
-        logger.debug("\nSupported Nodes:")
+        # Reformat support messages for debugger to print node overview as a single string
+        supported_nodes_str = "\nSupported Nodes:\n"
         for node_name in self.supported_operators:
-            logger.debug("-", node_name)
+            supported_nodes_str += f"- {node_name}\n"
+
+        logger.debug(supported_nodes_str)
 
         if len(self.unsupported_operators) != 0:
-            logger.debug("\nUnsupported or Excluded Nodes:")
+            unsupported_nodes_str = "\nUnsupported or Excluded Nodes:\n"
             for node_name in self.unsupported_operators:
-                logger.debug("-", node_name)
-            logger.debug("\n")
+                unsupported_nodes_str += f"- {node_name}\n"
+            logger.debug(unsupported_nodes_str)
         else:
             logger.debug("\nAll Nodes Supported\n")
 
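
The old calls passed the node name as a %-style formatting argument to a format string ("-") with no placeholder, so the logging module fails to format the record at emit time; accumulating one string per overview avoids that and keeps each overview in a single log record. A runnable illustration of both patterns (the node names are illustrative):

    import logging

    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger("demo")

    # Old pattern: "-" has no %-placeholder for the extra argument,
    # so logging reports "--- Logging error ---" when formatting the record
    logger.debug("-", "aten.add.Tensor")

    # New pattern: build one string, emit one record
    nodes = ["aten.add.Tensor", "aten.relu.default"]
    logger.debug("\nSupported Nodes:\n" + "".join(f"- {n}\n" for n in nodes))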

py/torch_tensorrt/dynamo/backend/test/utils.py

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@ def lower_graph_testing(
     torch_executed_ops: Sequence[str] = set(),
     testing_partitioning: bool = False,
 ):
-    """Helper function to assist with graph lowering for testing of Dynamo torch_compile
+    """Helper function to assist with graph lowering for testing of Dynamo compile
 
     Args:
         fx_graph: Graph to lower

py/torch_tensorrt/dynamo/common_utils/__init__.py

Whitespace-only changes.

py/torch_tensorrt/dynamo/test/test_dynamo_backend.py

Lines changed: 6 additions & 3 deletions
@@ -7,7 +7,10 @@
 
 from transformers import BertModel
 
-from utils import COSINE_THRESHOLD, cosine_similarity
+from torch_tensorrt.dynamo.common_utils.test_utils import (
+    COSINE_THRESHOLD,
+    cosine_similarity,
+)
 
 
 @pytest.mark.unit
@@ -30,7 +33,7 @@ def test_resnet18(ir):
     cos_sim = cosine_similarity(model(input), trt_mod(input))
     assert (
         cos_sim > COSINE_THRESHOLD,
-        f"Resnet50 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
 
     # Clean up model env
@@ -163,7 +166,7 @@ def test_resnet18_half(ir):
     cos_sim = cosine_similarity(model(input), trt_mod(input))
     assert (
         cos_sim > COSINE_THRESHOLD,
-        f"Resnet50 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
 
     # Clean up model env
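
The renamed messages now match the model under test; the imported helper compares flattened outputs by cosine similarity. A minimal stand-in for the imported utilities, assuming only their usage above; the real test_utils implementation and threshold may differ:

    import torch
    import torch.nn.functional as F

    COSINE_THRESHOLD = 0.99  # assumed value; actually defined in test_utils in the repo

    def cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float:
        # Compare the direction of the flattened outputs, ignoring magnitude
        return F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()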

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 1 addition & 1 deletion
@@ -358,7 +358,7 @@ def aten_ops_cat(
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     kwargs_new = {
         "tensors": args[0],
-        "dim": args[1],
+        "dim": args[1] if len(args) >= 2 else 0,
     }
     return acc_ops_converters.acc_ops_cat(network, target, None, kwargs_new, name)
 
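
The guard mirrors torch.cat's own default: dim is 0 when omitted, so the traced aten.cat call may arrive with only the tensor list. A quick check of that equivalence:

    import torch

    x, y = torch.randn(2, 3), torch.randn(4, 3)

    # dim defaults to 0, matching the converter's new fallback
    assert torch.equal(torch.cat((x, y)), torch.cat((x, y), dim=0))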

py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py

Lines changed: 35 additions & 0 deletions
@@ -53,6 +53,41 @@ def forward(self, x, y):
             expected_ops={torch.ops.aten.cat.default},
         )
 
+    def test_cat_no_dim(self):
+        class Cat(nn.Module):
+            def forward(self, x, y, z):
+                return torch.cat((x, y, z))
+
+        inputs = [torch.randn(2, 1, 3), torch.randn(1, 1, 3), torch.randn(3, 1, 3)]
+        self.run_test(
+            Cat(),
+            inputs,
+            expected_ops={torch.ops.aten.cat.default},
+        )
+
+    def test_cat_dynamic_shape_no_dim(self):
+        class Cat(nn.Module):
+            def forward(self, x, y):
+                return torch.cat((x, y))
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, 16, 3),
+                dtype=torch.float32,
+                shape_ranges=[((2, 16, 3), (3, 16, 3), (32, 16, 3))],
+            ),
+            InputTensorSpec(
+                shape=(-1, 16, 3),
+                dtype=torch.float32,
+                shape_ranges=[((2, 16, 3), (3, 16, 3), (32, 16, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Cat(),
+            input_specs,
+            expected_ops={torch.ops.aten.cat.default},
+        )
+
 
 if __name__ == "__main__":
     run_tests()
