 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
 
 import numpy as np
+
+# @manual=//deeplearning/trt/python:py_tensorrt
+import tensorrt as trt
 import torch
 import torch.fx
 from torch.fx.node import _get_qualified_name
@@ -23,8 +26,6 @@
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
-# @manual=//deeplearning/trt/python:py_tensorrt
-import tensorrt as trt
 from packaging import version
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -96,6 +97,7 @@ def __init__(
         self._itensor_to_tensor_meta: Dict[
             trt.tensorrt.ITensor, TensorMetadata
         ] = dict()
+        self.compilation_settings = compilation_settings
 
         # Data types for TRT Module output Tensors
         self.output_dtypes = output_dtypes
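Note for readers: the `compilation_settings` object stored here is what `run()` reads from in the hunks below. A minimal sketch of the fields this diff touches, assuming a dataclass-style container; the class name and defaults are illustrative, not the project's actual `CompilationSettings` definition:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ExampleCompilationSettings:
    """Illustrative subset of the settings read by run() in this diff."""

    precision: torch.dtype = torch.float32
    debug: bool = False
    workspace_size: int = 0          # 0 keeps TensorRT's default workspace limit
    max_aux_streams: Optional[int] = None
    version_compatible: bool = False
    optimization_level: Optional[int] = None
    num_avg_timing_iters: int = 1
    sparse_weights: bool = False
    disable_tf32: bool = False
    refit: bool = False
    # DLA memory pool sizes, in bytes (only used when targeting a DLA device)
    dla_sram_size: int = 1 << 20
    dla_local_dram_size: int = 1 << 30
    dla_global_dram_size: int = 512 << 20
    # The real settings object also carries `device` (device_type, dla_core)
    # and `engine_capability`, which are TensorRT-specific types.
```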
@@ -118,40 +120,25 @@ def validate_conversion(self) -> Set[str]:
 
     def run(
         self,
-        workspace_size: int = 0,
-        precision: torch.dtype = torch.float32,  # TODO: @peri044 Needs to be expanded to set
-        sparse_weights: bool = False,
-        disable_tf32: bool = False,
         force_fp32_output: bool = False,
         strict_type_constraints: bool = False,
         algorithm_selector: Optional[trt.IAlgorithmSelector] = None,
         timing_cache: Optional[trt.ITimingCache] = None,
-        profiling_verbosity: Optional[trt.ProfilingVerbosity] = None,
         tactic_sources: Optional[int] = None,
-        max_aux_streams: Optional[int] = None,
-        version_compatible: bool = False,
-        optimization_level: Optional[int] = None,
     ) -> TRTInterpreterResult:
         """
         Build TensorRT engine with some configs.
         Args:
-            workspace_size: Amount of memory used by TensorRT to store intermediate buffers within an operation.
-            precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision).
-            sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity
             force_fp32_output: force output to be fp32
             strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons.
             algorithm_selector: set up algorithm selection for certain layer
             timing_cache: enable timing cache for TensorRT
-            profiling_verbosity: TensorRT logging level
-            max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
-            version_compatible: Provide version forward-compatibility for engine plan files
-            optimization_level: Builder optimization 0-5, higher levels imply longer build time,
-                searching for more optimization options. TRT defaults to 3
         Return:
             TRTInterpreterResult
         """
         TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)
 
+        precision = self.compilation_settings.precision
         # For float outputs, we set their dtype to fp16 only if precision == torch.float16 and
         # force_fp32_output=False. Overriden by specifying output_dtypes
         self.output_fp16 = not force_fp32_output and precision == torch.float16
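The fp16-output decision now combines the stored precision with the remaining `force_fp32_output` argument. A small self-contained sketch of that rule (the helper name is ours, not part of the module):

```python
import torch


def resolve_output_fp16(precision: torch.dtype, force_fp32_output: bool) -> bool:
    # Float outputs are emitted as fp16 only when compiling at fp16 precision
    # and the caller has not forced fp32 outputs (mirrors the logic above).
    return not force_fp32_output and precision == torch.float16


assert resolve_output_fp16(torch.float16, force_fp32_output=False)
assert not resolve_output_fp16(torch.float16, force_fp32_output=True)
assert not resolve_output_fp16(torch.float32, force_fp32_output=False)
```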
@@ -172,9 +159,9 @@ def run(
 
         builder_config = self.builder.create_builder_config()
 
-        if workspace_size != 0:
+        if self.compilation_settings.workspace_size != 0:
             builder_config.set_memory_pool_limit(
-                trt.MemoryPoolType.WORKSPACE, workspace_size
+                trt.MemoryPoolType.WORKSPACE, self.compilation_settings.workspace_size
             )
 
         cache = None
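For reference, the workspace limit is plain TensorRT builder-config usage; a standalone sketch, assuming a local TensorRT 8.4+ installation (the 1 GiB value is only an example):

```python
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(TRT_LOGGER)
config = builder.create_builder_config()

workspace_size = 1 << 30  # bytes; 0 means "leave TensorRT's default limit untouched"
if workspace_size != 0:
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_size)
```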
@@ -187,34 +174,66 @@ def run(
 
         if version.parse(trt.__version__) >= version.parse("8.2"):
             builder_config.profiling_verbosity = (
-                profiling_verbosity
-                if profiling_verbosity
+                trt.ProfilingVerbosity.VERBOSE
+                if self.compilation_settings.debug
                 else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
             )
 
         if version.parse(trt.__version__) >= version.parse("8.6"):
-            if max_aux_streams is not None:
-                _LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
-                builder_config.max_aux_streams = max_aux_streams
-            if version_compatible:
+            if self.compilation_settings.max_aux_streams is not None:
+                _LOGGER.info(
+                    f"Setting max aux streams to {self.compilation_settings.max_aux_streams}"
+                )
+                builder_config.max_aux_streams = (
+                    self.compilation_settings.max_aux_streams
+                )
+            if self.compilation_settings.version_compatible:
                 _LOGGER.info("Using version compatible")
                 builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
-            if optimization_level is not None:
-                _LOGGER.info(f"Using optimization level {optimization_level}")
-                builder_config.builder_optimization_level = optimization_level
+            if self.compilation_settings.optimization_level is not None:
+                _LOGGER.info(
+                    f"Using optimization level {self.compilation_settings.optimization_level}"
+                )
+                builder_config.builder_optimization_level = (
+                    self.compilation_settings.optimization_level
+                )
+
+        builder_config.engine_capability = self.compilation_settings.engine_capability
+        builder_config.avg_timing_iterations = (
+            self.compilation_settings.num_avg_timing_iters
+        )
+
+        if self.compilation_settings.device.device_type == trt.DeviceType.DLA:
+            builder_config.DLA_core = self.compilation_settings.device.dla_core
+            _LOGGER.info(f"Using DLA core {self.compilation_settings.device.dla_core}")
+            builder_config.set_memory_pool_limit(
+                trt.MemoryPoolType.DLA_MANAGED_SRAM,
+                self.compilation_settings.dla_sram_size,
+            )
+            builder_config.set_memory_pool_limit(
+                trt.MemoryPoolType.DLA_LOCAL_DRAM,
+                self.compilation_settings.dla_local_dram_size,
+            )
+            builder_config.set_memory_pool_limit(
+                trt.MemoryPoolType.DLA_GLOBAL_DRAM,
+                self.compilation_settings.dla_global_dram_size,
+            )
 
         if precision == torch.float16:
             builder_config.set_flag(trt.BuilderFlag.FP16)
 
         if precision == torch.int8:
             builder_config.set_flag(trt.BuilderFlag.INT8)
 
-        if sparse_weights:
+        if self.compilation_settings.sparse_weights:
             builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)
 
-        if disable_tf32:
+        if self.compilation_settings.disable_tf32:
             builder_config.clear_flag(trt.BuilderFlag.TF32)
 
+        if self.compilation_settings.refit:
+            builder_config.set_flag(trt.BuilderFlag.REFIT)
+
         if strict_type_constraints:
             builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES)
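The DLA and precision handling added above maps directly onto standard TensorRT builder-config calls. A standalone sketch, assuming TensorRT 8.x on a DLA-capable platform (e.g. Jetson); pool sizes and flag choices are example values, not the defaults used by this module:

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
config = builder.create_builder_config()

# Route eligible layers to DLA core 0 and bound its memory pools (bytes).
config.default_device_type = trt.DeviceType.DLA
config.DLA_core = 0
config.set_memory_pool_limit(trt.MemoryPoolType.DLA_MANAGED_SRAM, 1 << 20)
config.set_memory_pool_limit(trt.MemoryPoolType.DLA_LOCAL_DRAM, 1 << 30)
config.set_memory_pool_limit(trt.MemoryPoolType.DLA_GLOBAL_DRAM, 512 << 20)

# Precision, sparsity, refit and timing knobs mirrored from the settings object.
config.set_flag(trt.BuilderFlag.FP16)
config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)
config.set_flag(trt.BuilderFlag.REFIT)
config.clear_flag(trt.BuilderFlag.TF32)
config.avg_timing_iterations = 2
```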