
Commit c46f614
fix: Unify export/compile compilation utilities
- Update argument naming for compatibility with `fx_ts_compat` utilities
- Add support for new TRT 8.6 utilities, including auxiliary streams, version compatibility, and optimization levels
- Add support for TRTModuleNext use during compilation with Dynamo compile
- Improve documentation of features and version checking for TRT feature compatibility
1 parent dd31c9a commit c46f614

6 files changed (+76, -17 lines)

py/torch_tensorrt/dynamo/backend/__init__.py

Lines changed: 33 additions & 6 deletions
@@ -4,7 +4,7 @@
 import torch_tensorrt
 from functools import partial
 
-from typing import Any, Sequence
+from typing import Any, Optional, Sequence
 from torch_tensorrt import EngineCapability, Device
 from torch_tensorrt.fx.utils import LowerPrecision
 
@@ -14,9 +14,13 @@
 from torch_tensorrt.dynamo.backend._defaults import (
     PRECISION,
     DEBUG,
-    MAX_WORKSPACE_SIZE,
+    WORKSPACE_SIZE,
     MIN_BLOCK_SIZE,
     PASS_THROUGH_BUILD_FAILURES,
+    MAX_AUX_STREAMS,
+    VERSION_COMPATIBLE,
+    OPTIMIZATION_LEVEL,
+    USE_EXPERIMENTAL_RT,
 )
 
 
@@ -35,7 +39,7 @@ def compile(
     debug=DEBUG,
     capability=EngineCapability.default,
     num_avg_timing_iters=1,
-    workspace_size=MAX_WORKSPACE_SIZE,
+    workspace_size=WORKSPACE_SIZE,
     dla_sram_size=1048576,
     dla_local_dram_size=1073741824,
     dla_global_dram_size=536870912,
@@ -45,6 +49,10 @@ def compile(
     min_block_size=MIN_BLOCK_SIZE,
     torch_executed_ops=[],
     torch_executed_modules=[],
+    max_aux_streams=MAX_AUX_STREAMS,
+    version_compatible=VERSION_COMPATIBLE,
+    optimization_level=OPTIMIZATION_LEVEL,
+    use_experimental_rt=USE_EXPERIMENTAL_RT,
     **kwargs,
 ):
     if debug:
@@ -86,6 +94,10 @@ def compile(
         workspace_size=workspace_size,
         min_block_size=min_block_size,
         torch_executed_ops=torch_executed_ops,
+        max_aux_streams=max_aux_streams,
+        version_compatible=version_compatible,
+        optimization_level=optimization_level,
+        use_experimental_rt=use_experimental_rt,
         **kwargs,
     )
 
@@ -105,19 +117,30 @@ def compile(
 def create_backend(
     precision: LowerPrecision = PRECISION,
     debug: bool = DEBUG,
-    workspace_size: int = MAX_WORKSPACE_SIZE,
+    workspace_size: int = WORKSPACE_SIZE,
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Sequence[str] = set(),
     pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
+    max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
+    version_compatible: bool = VERSION_COMPATIBLE,
+    optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
+    use_experimental_rt: bool = USE_EXPERIMENTAL_RT,
     **kwargs,
 ):
     """Create torch.compile backend given specified arguments
 
     Args:
         precision:
         debug: Whether to print out verbose debugging information
-        workspace_size: Maximum workspace TRT is allowed to use for the module
-        precision: Model Layer precision
+        workspace_size: Workspace TRT is allowed to use for the module (0 is default)
+        min_block_size: Minimum number of operators per TRT-Engine Block
+        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
+        pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
+        max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
+        version_compatible: Provide version forward-compatibility for engine plan files
+        optimization_level: Builder optimization 0-5, higher levels imply longer build time,
+            searching for more optimization options. TRT defaults to 3
+        use_experimental_rt: Whether to use the new experimental TRTModuleNext for TRT engines
     Returns:
         Backend for torch.compile
     """
@@ -131,6 +154,10 @@ def create_backend(
         min_block_size=min_block_size,
         torch_executed_ops=torch_executed_ops,
         pass_through_build_failures=pass_through_build_failures,
+        max_aux_streams=max_aux_streams,
+        version_compatible=version_compatible,
+        optimization_level=optimization_level,
+        use_experimental_rt=use_experimental_rt,
     )
 
     return partial(
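
Taken together, the `__init__.py` changes thread the four new TRT 8.6 options from the user-facing `compile` entrypoint down into the backend. A usage sketch, assuming `compile` is reachable as `torch_tensorrt.dynamo.backend.compile`; the model and input shape are placeholders:

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()                # placeholder module
    inputs = [torch.randn(1, 3, 224, 224).cuda()]  # placeholder input

    trt_model = torch_tensorrt.dynamo.backend.compile(
        model,
        inputs,
        workspace_size=0,           # 0 defers the workspace cap to TensorRT
        max_aux_streams=2,          # limit auxiliary CUDA streams per engine
        version_compatible=True,    # build forward-compatible engine plans
        optimization_level=3,       # builder effort 0-5; TRT defaults to 3
        use_experimental_rt=False,  # set True to wrap engines in TRTModuleNext
    )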

py/torch_tensorrt/dynamo/backend/_defaults.py

Lines changed: 5 additions & 1 deletion
@@ -3,6 +3,10 @@
 
 PRECISION = LowerPrecision.FP32
 DEBUG = False
-MAX_WORKSPACE_SIZE = 20 << 30
+WORKSPACE_SIZE = 0
 MIN_BLOCK_SIZE = 5
 PASS_THROUGH_BUILD_FAILURES = False
+MAX_AUX_STREAMS = None
+VERSION_COMPATIBLE = False
+OPTIMIZATION_LEVEL = None
+USE_EXPERIMENTAL_RT = False
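
Note the shift in defaults: the old hard 20 GiB cap (`MAX_WORKSPACE_SIZE = 20 << 30`) becomes `WORKSPACE_SIZE = 0`, and the `None` values likewise defer to TensorRT's builder defaults. A sketch of opting back into an explicit cap through the `create_backend` function from the first file:

    # Restore the previous fixed 20 GiB workspace ceiling explicitly
    backend = create_backend(workspace_size=20 << 30)

    # Or keep the new behavior: 0 and None defer to TensorRT's own defaults
    backend = create_backend(workspace_size=0, optimization_level=None)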
py/torch_tensorrt/dynamo/backend/_settings.py

Lines changed: 11 additions & 3 deletions
@@ -1,21 +1,29 @@
 from dataclasses import dataclass, field
-from typing import Sequence
+from typing import Optional, Sequence
 
 from torch_tensorrt.fx.utils import LowerPrecision
 from torch_tensorrt.dynamo.backend._defaults import (
     PRECISION,
     DEBUG,
-    MAX_WORKSPACE_SIZE,
+    WORKSPACE_SIZE,
     MIN_BLOCK_SIZE,
     PASS_THROUGH_BUILD_FAILURES,
+    MAX_AUX_STREAMS,
+    VERSION_COMPATIBLE,
+    OPTIMIZATION_LEVEL,
+    USE_EXPERIMENTAL_RT,
 )
 
 
 @dataclass(frozen=True)
 class CompilationSettings:
     precision: LowerPrecision = PRECISION
     debug: bool = DEBUG
-    workspace_size: int = MAX_WORKSPACE_SIZE
+    workspace_size: int = WORKSPACE_SIZE
     min_block_size: int = MIN_BLOCK_SIZE
     torch_executed_ops: Sequence[str] = field(default_factory=set)
     pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES
+    max_aux_streams: Optional[int] = MAX_AUX_STREAMS
+    version_compatible: bool = VERSION_COMPATIBLE
+    optimization_level: Optional[int] = OPTIMIZATION_LEVEL
+    use_experimental_rt: bool = USE_EXPERIMENTAL_RT
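
Because `CompilationSettings` is a frozen dataclass, the new knobs are fixed at construction time and flow unchanged through partitioning and conversion. A minimal sketch:

    from torch_tensorrt.dynamo.backend._settings import CompilationSettings

    settings = CompilationSettings(
        max_aux_streams=1,
        version_compatible=True,
        optimization_level=5,
    )
    # frozen=True: assigning settings.optimization_level = 3 would raise
    # dataclasses.FrozenInstanceError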

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 1 addition & 0 deletions
@@ -135,6 +135,7 @@ def _compile_module(
         submodule,
         submodule_inputs,
         settings=settings,
+        name=name,
     )
 
     # Replace FX Module with TRT Module

py/torch_tensorrt/dynamo/backend/conversion.py

Lines changed: 23 additions & 5 deletions
@@ -3,7 +3,7 @@
 from torch_tensorrt.fx.trt_module import TRTModule
 from torch_tensorrt import TRTModuleNext
 from torch_tensorrt.dynamo.backend._settings import CompilationSettings
-from torch_tensorrt.fx.fx2trt import (
+from torch_tensorrt.dynamo.fx_ts_compat.fx2trt import (
     InputTensorSpec,
     TRTInterpreter,
 )
@@ -15,30 +15,48 @@ def convert_module(
     module: torch.fx.GraphModule,
     inputs: Sequence[torch.Tensor],
     settings: CompilationSettings = CompilationSettings(),
+    name: str = "",
 ) -> Union[TRTModuleNext, TRTModule]:
     """Convert an FX module to a TRT module
     Args:
         module: FX GraphModule to convert
         inputs: Sequence of Tensors representing inputs to the module
         settings: Compilation settings
+        name: TRT engine name
     Returns:
         TRTModule or TRTModuleNext
     """
-    interp = TRTInterpreter(
+    interpreter = TRTInterpreter(
         module,
         InputTensorSpec.from_tensors(inputs),
         explicit_batch_dimension=True,
         logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
     )
 
-    r = interp.run(
-        max_workspace_size=settings.workspace_size,
+    interpreter_result = interpreter.run(
+        workspace_size=settings.workspace_size,
         lower_precision=settings.precision,
         profiling_verbosity=(
             trt.ProfilingVerbosity.VERBOSE
             if settings.debug
             else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
         ),
+        max_aux_streams=settings.max_aux_streams,
+        version_compatible=settings.version_compatible,
+        optimization_level=settings.optimization_level,
     )
 
-    return TRTModule(*r)
+    return (
+        TRTModuleNext(
+            serialized_engine=interpreter_result.engine,
+            name=name,
+            input_binding_names=interpreter_result.input_names,
+            output_binding_names=interpreter_result.output_names,
+        )
+        if settings.use_experimental_rt
+        else TRTModule(
+            engine=interpreter_result.engine,
+            input_names=interpreter_result.input_names,
+            output_names=interpreter_result.output_names,
+        )
+    )
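
The rewritten return expression is where `use_experimental_rt` takes effect: the same interpreter result is wrapped either in the experimental `TRTModuleNext` (which also receives the new engine `name`) or the existing `TRTModule`. A calling sketch, with `fx_module`, `sample_inputs`, and the engine name as placeholders:

    settings = CompilationSettings(use_experimental_rt=True)
    trt_mod = convert_module(
        fx_module,
        sample_inputs,
        settings=settings,
        name="submodule_0",  # placeholder engine name
    )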

py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py

Lines changed: 3 additions & 2 deletions
@@ -1,6 +1,7 @@
 import logging
 import warnings
 from datetime import datetime
+from packaging import version
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence
 
 import numpy
@@ -224,14 +225,14 @@ def run(
         cache = builder_config.create_timing_cache(b"")
         builder_config.set_timing_cache(cache, False)
 
-        if trt.__version__ >= "8.2":
+        if version.parse(trt.__version__) >= version.parse("8.2"):
             builder_config.profiling_verbosity = (
                 profiling_verbosity
                 if profiling_verbosity
                 else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
             )
 
-        if trt.__version__ >= "8.6":
+        if version.parse(trt.__version__) >= version.parse("8.6"):
             if max_aux_streams is not None:
                 _LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
                 builder_config.max_aux_streams = max_aux_streams
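
Switching from raw string comparison to `packaging.version` fixes a latent bug: string comparison is lexicographic, so any future two-digit minor release would misorder. A quick illustration:

    from packaging import version

    "8.10" >= "8.6"                                # False: '1' < '6' as characters
    version.parse("8.10") >= version.parse("8.6")  # True: correct semantic ordering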
