Skip to content

Commit cb2be86

Browse files
committed
fix: Unify export/compile compilation utilities
- Add support for new TRT 8.6 utilities, including auxiliary streams, version compatibility, and optimization levels - Add support for TRTModuleNext use during compilation with Dynamo compile - Improve documentation of features and version checking for TRT feature compatibility
1 parent dd31c9a commit cb2be86

File tree

6 files changed

+68
-9
lines changed

6 files changed

+68
-9
lines changed

py/torch_tensorrt/dynamo/backend/__init__.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch_tensorrt
55
from functools import partial
66

7-
from typing import Any, Sequence
7+
from typing import Any, Optional, Sequence
88
from torch_tensorrt import EngineCapability, Device
99
from torch_tensorrt.fx.utils import LowerPrecision
1010

@@ -17,6 +17,10 @@
1717
MAX_WORKSPACE_SIZE,
1818
MIN_BLOCK_SIZE,
1919
PASS_THROUGH_BUILD_FAILURES,
20+
MAX_AUX_STREAMS,
21+
VERSION_COMPATIBLE,
22+
OPTIMIZATION_LEVEL,
23+
USE_EXPERIMENTAL_RT,
2024
)
2125

2226

@@ -45,6 +49,10 @@ def compile(
4549
min_block_size=MIN_BLOCK_SIZE,
4650
torch_executed_ops=[],
4751
torch_executed_modules=[],
52+
max_aux_streams=MAX_AUX_STREAMS,
53+
version_compatible=VERSION_COMPATIBLE,
54+
optimization_level=OPTIMIZATION_LEVEL,
55+
use_experimental_rt=USE_EXPERIMENTAL_RT,
4856
**kwargs,
4957
):
5058
if debug:
@@ -86,6 +94,10 @@ def compile(
8694
workspace_size=workspace_size,
8795
min_block_size=min_block_size,
8896
torch_executed_ops=torch_executed_ops,
97+
max_aux_streams=max_aux_streams,
98+
version_compatible=version_compatible,
99+
optimization_level=optimization_level,
100+
use_experimental_rt=use_experimental_rt,
89101
**kwargs,
90102
)
91103

@@ -109,6 +121,10 @@ def create_backend(
109121
min_block_size: int = MIN_BLOCK_SIZE,
110122
torch_executed_ops: Sequence[str] = set(),
111123
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
124+
max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
125+
version_compatible: bool = VERSION_COMPATIBLE,
126+
optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
127+
use_experimental_rt: bool = USE_EXPERIMENTAL_RT,
112128
**kwargs,
113129
):
114130
"""Create torch.compile backend given specified arguments
@@ -117,7 +133,14 @@ def create_backend(
117133
precision:
118134
debug: Whether to print out verbose debugging information
119135
workspace_size: Maximum workspace TRT is allowed to use for the module
120-
precision: Model Layer precision
136+
min_block_size: Minimum number of operators per TRT-Engine Block
137+
torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
138+
pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
139+
max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
140+
version_compatible: Provide version forward-compatibility for engine plan files
141+
optimization_level: Builder optimization 0-5, higher levels imply longer build time,
142+
searching for more optimization options. TRT defaults to 3
143+
use_experimental_rt: Whether to use the new experimental TRTModuleNext for TRT engines
121144
Returns:
122145
Backend for torch.compile
123146
"""
@@ -131,6 +154,10 @@ def create_backend(
131154
min_block_size=min_block_size,
132155
torch_executed_ops=torch_executed_ops,
133156
pass_through_build_failures=pass_through_build_failures,
157+
max_aux_streams=max_aux_streams,
158+
version_compatible=version_compatible,
159+
optimization_level=optimization_level,
160+
use_experimental_rt=use_experimental_rt,
134161
)
135162

136163
return partial(

py/torch_tensorrt/dynamo/backend/_defaults.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,7 @@
66
MAX_WORKSPACE_SIZE = 20 << 30
77
MIN_BLOCK_SIZE = 5
88
PASS_THROUGH_BUILD_FAILURES = False
9+
MAX_AUX_STREAMS = None
10+
VERSION_COMPATIBLE = False
11+
OPTIMIZATION_LEVEL = None
12+
USE_EXPERIMENTAL_RT = False

py/torch_tensorrt/dynamo/backend/_settings.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass, field
2-
from typing import Sequence
2+
from typing import Optional, Sequence
33

44
from torch_tensorrt.fx.utils import LowerPrecision
55
from torch_tensorrt.dynamo.backend._defaults import (
@@ -8,6 +8,10 @@
88
MAX_WORKSPACE_SIZE,
99
MIN_BLOCK_SIZE,
1010
PASS_THROUGH_BUILD_FAILURES,
11+
MAX_AUX_STREAMS,
12+
VERSION_COMPATIBLE,
13+
OPTIMIZATION_LEVEL,
14+
USE_EXPERIMENTAL_RT,
1115
)
1216

1317

@@ -19,3 +23,7 @@ class CompilationSettings:
1923
min_block_size: int = MIN_BLOCK_SIZE
2024
torch_executed_ops: Sequence[str] = field(default_factory=set)
2125
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES
26+
max_aux_streams: Optional[int] = MAX_AUX_STREAMS
27+
version_compatible: bool = VERSION_COMPATIBLE
28+
optimization_level: Optional[int] = OPTIMIZATION_LEVEL
29+
use_experimental_rt: bool = USE_EXPERIMENTAL_RT

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def _compile_module(
135135
submodule,
136136
submodule_inputs,
137137
settings=settings,
138+
name=name,
138139
)
139140

140141
# Replace FX Module with TRT Module

py/torch_tensorrt/dynamo/backend/conversion.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from torch_tensorrt.fx.trt_module import TRTModule
44
from torch_tensorrt import TRTModuleNext
55
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
6-
from torch_tensorrt.fx.fx2trt import (
6+
from torch_tensorrt.dynamo.fx_ts_compat.fx2trt import (
77
InputTensorSpec,
88
TRTInterpreter,
99
)
@@ -15,30 +15,48 @@ def convert_module(
1515
module: torch.fx.GraphModule,
1616
inputs: Sequence[torch.Tensor],
1717
settings: CompilationSettings = CompilationSettings(),
18+
name: str = "",
1819
) -> Union[TRTModuleNext, TRTModule]:
1920
"""Convert an FX module to a TRT module
2021
Args:
2122
module: FX GraphModule to convert
2223
inputs: Sequence of Tensors representing inputs to the module
2324
settings: Compilation settings
25+
name: TRT engine name
2426
Returns:
2527
TRTModule or TRTModuleNext
2628
"""
27-
interp = TRTInterpreter(
29+
interpreter = TRTInterpreter(
2830
module,
2931
InputTensorSpec.from_tensors(inputs),
3032
explicit_batch_dimension=True,
3133
logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
3234
)
3335

34-
r = interp.run(
36+
interpreter_result = interpreter.run(
3537
max_workspace_size=settings.workspace_size,
3638
lower_precision=settings.precision,
3739
profiling_verbosity=(
3840
trt.ProfilingVerbosity.VERBOSE
3941
if settings.debug
4042
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
4143
),
44+
max_aux_streams=settings.max_aux_streams,
45+
version_compatible=settings.version_compatible,
46+
optimization_level=settings.optimization_level,
4247
)
4348

44-
return TRTModule(*r)
49+
return (
50+
TRTModuleNext(
51+
serialized_engine=interpreter_result.engine,
52+
name=name,
53+
input_binding_names=interpreter_result.input_names,
54+
output_binding_names=interpreter_result.output_names,
55+
)
56+
if settings.use_experimental_rt
57+
else TRTModule(
58+
engine=interpreter_result.engine,
59+
input_names=interpreter_result.input_names,
60+
output_names=interpreter_result.output_names,
61+
)
62+
)

py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
import warnings
33
from datetime import datetime
4+
from packaging import version
45
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence
56

67
import numpy
@@ -224,14 +225,14 @@ def run(
224225
cache = builder_config.create_timing_cache(b"")
225226
builder_config.set_timing_cache(cache, False)
226227

227-
if trt.__version__ >= "8.2":
228+
if version.parse(trt.__version__) >= version.parse("8.2"):
228229
builder_config.profiling_verbosity = (
229230
profiling_verbosity
230231
if profiling_verbosity
231232
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
232233
)
233234

234-
if trt.__version__ >= "8.6":
235+
if version.parse(trt.__version__) >= version.parse("8.6"):
235236
if max_aux_streams is not None:
236237
_LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
237238
builder_config.max_aux_streams = max_aux_streams

0 commit comments

Comments
 (0)