Skip to content

Commit 338de01

Browse files
committed
fix/feat: Add support for multiple TRT Build Args
- Add support in the Dynamo path for additional build arguments, bringing build-argument parity with the TorchScript path to 1:1
- Add tests for the new build arguments
- Modify the TRT Interpreter and conversion phases accordingly
1 parent 2f9b259 commit 338de01

File tree

7 files changed

+216
-54
lines changed

7 files changed

+216
-54
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,21 @@
1616
from torch_tensorrt.dynamo._defaults import (
1717
DEBUG,
1818
DEVICE,
19+
DISABLE_TF32,
20+
DLA_GLOBAL_DRAM_SIZE,
21+
DLA_LOCAL_DRAM_SIZE,
22+
DLA_SRAM_SIZE,
1923
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
24+
ENGINE_CAPABILITY,
2025
MAX_AUX_STREAMS,
2126
MIN_BLOCK_SIZE,
27+
NUM_AVG_TIMING_ITERS,
2228
OPTIMIZATION_LEVEL,
2329
PASS_THROUGH_BUILD_FAILURES,
2430
PRECISION,
31+
REFIT,
2532
REQUIRE_FULL_COMPILATION,
33+
SPARSE_WEIGHTS,
2634
TRUNCATE_LONG_AND_DOUBLE,
2735
USE_FAST_PARTITIONER,
2836
USE_PYTHON_RUNTIME,
@@ -51,17 +59,18 @@ def compile(
5159
inputs: Tuple[Any, ...],
5260
*,
5361
device: Optional[Union[Device, torch.device, str]] = DEVICE,
54-
disable_tf32: bool = False,
55-
sparse_weights: bool = False,
62+
disable_tf32: bool = DISABLE_TF32,
63+
sparse_weights: bool = SPARSE_WEIGHTS,
5664
enabled_precisions: Set[torch.dtype] | Tuple[torch.dtype] = (torch.float32,),
57-
refit: bool = False,
65+
engine_capability: EngineCapability = ENGINE_CAPABILITY,
66+
refit: bool = REFIT,
5867
debug: bool = DEBUG,
5968
capability: EngineCapability = EngineCapability.default,
60-
num_avg_timing_iters: int = 1,
69+
num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS,
6170
workspace_size: int = WORKSPACE_SIZE,
62-
dla_sram_size: int = 1048576,
63-
dla_local_dram_size: int = 1073741824,
64-
dla_global_dram_size: int = 536870912,
71+
dla_sram_size: int = DLA_SRAM_SIZE,
72+
dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE,
73+
dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE,
6574
calibrator: object = None,
6675
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
6776
require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
@@ -192,6 +201,13 @@ def compile(
192201
"use_fast_partitioner": use_fast_partitioner,
193202
"enable_experimental_decompositions": enable_experimental_decompositions,
194203
"require_full_compilation": require_full_compilation,
204+
"disable_tf32": disable_tf32,
205+
"sparse_weights": sparse_weights,
206+
"refit": refit,
207+
"engine_capability": engine_capability,
208+
"dla_sram_size": dla_sram_size,
209+
"dla_local_dram_size": dla_local_dram_size,
210+
"dla_global_dram_size": dla_global_dram_size,
195211
}
196212

197213
settings = CompilationSettings(**compilation_options)

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,28 @@
11
import torch
2+
from tensorrt import EngineCapability
23
from torch_tensorrt._Device import Device
34

45
PRECISION = torch.float32
56
DEBUG = False
67
DEVICE = None
8+
DISABLE_TF32 = False
9+
DLA_LOCAL_DRAM_SIZE = 1073741824
10+
DLA_GLOBAL_DRAM_SIZE = 536870912
11+
DLA_SRAM_SIZE = 1048576
12+
ENGINE_CAPABILITY = EngineCapability.STANDARD
713
WORKSPACE_SIZE = 0
814
MIN_BLOCK_SIZE = 5
915
PASS_THROUGH_BUILD_FAILURES = False
1016
MAX_AUX_STREAMS = None
17+
NUM_AVG_TIMING_ITERS = 1
1118
VERSION_COMPATIBLE = False
1219
OPTIMIZATION_LEVEL = None
20+
SPARSE_WEIGHTS = False
1321
TRUNCATE_LONG_AND_DOUBLE = False
1422
USE_PYTHON_RUNTIME = False
1523
USE_FAST_PARTITIONER = True
1624
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
25+
REFIT = False
1726
REQUIRE_FULL_COMPILATION = False
1827

1928

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,25 @@
22
from typing import Optional, Set
33

44
import torch
5+
from tensorrt import EngineCapability
56
from torch_tensorrt._Device import Device
67
from torch_tensorrt.dynamo._defaults import (
78
DEBUG,
9+
DISABLE_TF32,
10+
DLA_GLOBAL_DRAM_SIZE,
11+
DLA_LOCAL_DRAM_SIZE,
12+
DLA_SRAM_SIZE,
813
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
14+
ENGINE_CAPABILITY,
915
MAX_AUX_STREAMS,
1016
MIN_BLOCK_SIZE,
17+
NUM_AVG_TIMING_ITERS,
1118
OPTIMIZATION_LEVEL,
1219
PASS_THROUGH_BUILD_FAILURES,
1320
PRECISION,
21+
REFIT,
1422
REQUIRE_FULL_COMPILATION,
23+
SPARSE_WEIGHTS,
1524
TRUNCATE_LONG_AND_DOUBLE,
1625
USE_FAST_PARTITIONER,
1726
USE_PYTHON_RUNTIME,
@@ -46,6 +55,14 @@ class CompilationSettings:
4655
device (Device): GPU to compile the model on
4756
require_full_compilation (bool): Whether to require the graph is fully compiled in TensorRT.
4857
Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path
58+
disable_tf32 (bool): Whether to disable TF32 computation for TRT layers
59+
sparse_weights (bool): Whether to allow the builder to use sparse weights
60+
refit (bool): Whether to build a refittable engine
61+
engine_capability (trt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
62+
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
63+
dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
64+
dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
65+
dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
4966
"""
5067

5168
precision: torch.dtype = PRECISION
@@ -63,3 +80,11 @@ class CompilationSettings:
6380
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
6481
device: Device = field(default_factory=default_device)
6582
require_full_compilation: bool = REQUIRE_FULL_COMPILATION
83+
disable_tf32: bool = DISABLE_TF32
84+
sparse_weights: bool = SPARSE_WEIGHTS
85+
refit: bool = REFIT
86+
engine_capability: EngineCapability = ENGINE_CAPABILITY
87+
num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS
88+
dla_sram_size: int = DLA_SRAM_SIZE
89+
dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE
90+
dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 50 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
55

66
import numpy as np
7+
8+
# @manual=//deeplearning/trt/python:py_tensorrt
9+
import tensorrt as trt
710
import torch
811
import torch.fx
912
from torch.fx.node import _get_qualified_name
@@ -23,8 +26,6 @@
2326
from torch_tensorrt.fx.observer import Observer
2427
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
2528

26-
# @manual=//deeplearning/trt/python:py_tensorrt
27-
import tensorrt as trt
2829
from packaging import version
2930

3031
_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -96,6 +97,7 @@ def __init__(
9697
self._itensor_to_tensor_meta: Dict[
9798
trt.tensorrt.ITensor, TensorMetadata
9899
] = dict()
100+
self.compilation_settings = compilation_settings
99101

100102
# Data types for TRT Module output Tensors
101103
self.output_dtypes = output_dtypes
@@ -118,40 +120,25 @@ def validate_conversion(self) -> Set[str]:
118120

119121
def run(
120122
self,
121-
workspace_size: int = 0,
122-
precision: torch.dtype = torch.float32, # TODO: @peri044 Needs to be expanded to set
123-
sparse_weights: bool = False,
124-
disable_tf32: bool = False,
125123
force_fp32_output: bool = False,
126124
strict_type_constraints: bool = False,
127125
algorithm_selector: Optional[trt.IAlgorithmSelector] = None,
128126
timing_cache: Optional[trt.ITimingCache] = None,
129-
profiling_verbosity: Optional[trt.ProfilingVerbosity] = None,
130127
tactic_sources: Optional[int] = None,
131-
max_aux_streams: Optional[int] = None,
132-
version_compatible: bool = False,
133-
optimization_level: Optional[int] = None,
134128
) -> TRTInterpreterResult:
135129
"""
136130
Build TensorRT engine with some configs.
137131
Args:
138-
workspace_size: Amount of memory used by TensorRT to store intermediate buffers within an operation.
139-
precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision).
140-
sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity
141132
force_fp32_output: force output to be fp32
142133
strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons.
143134
algorithm_selector: set up algorithm selection for certain layer
144135
timing_cache: enable timing cache for TensorRT
145-
profiling_verbosity: TensorRT logging level
146-
max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
147-
version_compatible: Provide version forward-compatibility for engine plan files
148-
optimization_level: Builder optimization 0-5, higher levels imply longer build time,
149-
searching for more optimization options. TRT defaults to 3
150136
Return:
151137
TRTInterpreterResult
152138
"""
153139
TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)
154140

141+
precision = self.compilation_settings.precision
155142
# For float outputs, we set their dtype to fp16 only if precision == torch.float16 and
156143
# force_fp32_output=False. Overridden by specifying output_dtypes
157144
self.output_fp16 = not force_fp32_output and precision == torch.float16
@@ -172,9 +159,9 @@ def run(
172159

173160
builder_config = self.builder.create_builder_config()
174161

175-
if workspace_size != 0:
162+
if self.compilation_settings.workspace_size != 0:
176163
builder_config.set_memory_pool_limit(
177-
trt.MemoryPoolType.WORKSPACE, workspace_size
164+
trt.MemoryPoolType.WORKSPACE, self.compilation_settings.workspace_size
178165
)
179166

180167
cache = None
@@ -187,34 +174,66 @@ def run(
187174

188175
if version.parse(trt.__version__) >= version.parse("8.2"):
189176
builder_config.profiling_verbosity = (
190-
profiling_verbosity
191-
if profiling_verbosity
177+
trt.ProfilingVerbosity.VERBOSE
178+
if self.compilation_settings.debug
192179
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
193180
)
194181

195182
if version.parse(trt.__version__) >= version.parse("8.6"):
196-
if max_aux_streams is not None:
197-
_LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
198-
builder_config.max_aux_streams = max_aux_streams
199-
if version_compatible:
183+
if self.compilation_settings.max_aux_streams is not None:
184+
_LOGGER.info(
185+
f"Setting max aux streams to {self.compilation_settings.max_aux_streams}"
186+
)
187+
builder_config.max_aux_streams = (
188+
self.compilation_settings.max_aux_streams
189+
)
190+
if self.compilation_settings.version_compatible:
200191
_LOGGER.info("Using version compatible")
201192
builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
202-
if optimization_level is not None:
203-
_LOGGER.info(f"Using optimization level {optimization_level}")
204-
builder_config.builder_optimization_level = optimization_level
193+
if self.compilation_settings.optimization_level is not None:
194+
_LOGGER.info(
195+
f"Using optimization level {self.compilation_settings.optimization_level}"
196+
)
197+
builder_config.builder_optimization_level = (
198+
self.compilation_settings.optimization_level
199+
)
200+
201+
builder_config.engine_capability = self.compilation_settings.engine_capability
202+
builder_config.avg_timing_iterations = (
203+
self.compilation_settings.num_avg_timing_iters
204+
)
205+
206+
if self.compilation_settings.device.device_type == trt.DeviceType.DLA:
207+
builder_config.DLA_core = self.compilation_settings.device.dla_core
208+
_LOGGER.info(f"Using DLA core {self.compilation_settings.device.dla_core}")
209+
builder_config.set_memory_pool_limit(
210+
trt.MemoryPoolType.DLA_MANAGED_SRAM,
211+
self.compilation_settings.dla_sram_size,
212+
)
213+
builder_config.set_memory_pool_limit(
214+
trt.MemoryPoolType.DLA_LOCAL_DRAM,
215+
self.compilation_settings.dla_local_dram_size,
216+
)
217+
builder_config.set_memory_pool_limit(
218+
trt.MemoryPoolType.DLA_GLOBAL_DRAM,
219+
self.compilation_settings.dla_global_dram_size,
220+
)
205221

206222
if precision == torch.float16:
207223
builder_config.set_flag(trt.BuilderFlag.FP16)
208224

209225
if precision == torch.int8:
210226
builder_config.set_flag(trt.BuilderFlag.INT8)
211227

212-
if sparse_weights:
228+
if self.compilation_settings.sparse_weights:
213229
builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)
214230

215-
if disable_tf32:
231+
if self.compilation_settings.disable_tf32:
216232
builder_config.clear_flag(trt.BuilderFlag.TF32)
217233

234+
if self.compilation_settings.refit:
235+
builder_config.set_flag(trt.BuilderFlag.REFIT)
236+
218237
if strict_type_constraints:
219238
builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES)
220239

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,14 @@
33
import io
44
from typing import Sequence
55

6+
import tensorrt as trt
67
import torch
78
from torch_tensorrt._Input import Input
89
from torch_tensorrt.dynamo._settings import CompilationSettings
910
from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
1011
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
1112
from torch_tensorrt.dynamo.utils import get_torch_inputs
1213

13-
import tensorrt as trt
14-
1514

1615
def convert_module(
1716
module: torch.fx.GraphModule,
@@ -54,18 +53,7 @@ def convert_module(
5453
output_dtypes=output_dtypes,
5554
compilation_settings=settings,
5655
)
57-
interpreter_result = interpreter.run(
58-
workspace_size=settings.workspace_size,
59-
precision=settings.precision,
60-
profiling_verbosity=(
61-
trt.ProfilingVerbosity.VERBOSE
62-
if settings.debug
63-
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
64-
),
65-
max_aux_streams=settings.max_aux_streams,
66-
version_compatible=settings.version_compatible,
67-
optimization_level=settings.optimization_level,
68-
)
56+
interpreter_result = interpreter.run()
6957

7058
if settings.use_python_runtime:
7159
return PythonTorchTensorRTModule(

tests/py/dynamo/conversion/harness.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def run_test(
6060

6161
mod.eval()
6262
start = time.perf_counter()
63-
interpreter_result = interpreter.run(precision=precision)
63+
interpreter_result = interpreter.run()
6464
sec = time.perf_counter() - start
6565
_LOGGER.info(f"Interpreter run time(s): {sec}")
6666
trt_mod = PythonTorchTensorRTModule(
@@ -234,7 +234,9 @@ def run_test(
234234

235235
# Previous instance of the interpreter auto-casted 64-bit inputs
236236
# We replicate this behavior here
237-
compilation_settings = CompilationSettings(truncate_long_and_double=True)
237+
compilation_settings = CompilationSettings(
238+
precision=precision, truncate_long_and_double=True
239+
)
238240

239241
interp = TRTInterpreter(
240242
mod,

0 commit comments

Comments
 (0)