import warnings
from datetime import datetime
from packaging import version
- from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence
+ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set

import numpy

@@ -41,12 +41,13 @@ def __init__(
        self,
        module: torch.fx.GraphModule,
        input_specs: List[Input],
-       logger_level=None,
-       output_dtypes=None,
+       logger_level: trt.ILogger.Severity = trt.ILogger.Severity.WARNING,
+       output_dtypes: Optional[List[torch.dtype]] = None,
    ):
        super().__init__(module)

-       self.logger = trt.Logger(logger_level or trt.Logger.WARNING)
+       # TODO: @narendasan replace with Torch-TensorRT Logger
+       self.logger = trt.Logger(logger_level)
        self.builder = trt.Builder(self.logger)

        flag = 0
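
The hunk above replaces the `None`-plus-fallback logger default with a typed `trt.ILogger.Severity` default, so the `or trt.Logger.WARNING` fallback can be dropped. A minimal sketch of the equivalence, using only the standard TensorRT Python API (not part of this diff):

import tensorrt as trt

# Old default path: logger_level=None fell back to trt.Logger.WARNING.
legacy_logger = trt.Logger(trt.Logger.WARNING)
# New default path: the annotated default severity is passed through unchanged.
typed_logger = trt.Logger(trt.ILogger.Severity.WARNING)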
@@ -59,12 +60,13 @@ def __init__(

        missing_ops = self.validate_conversion()
        if missing_ops:
+           # TODO: @narendasan make sure to set logging.captureWarnings(True)
            warnings.warn(
                "Interpretation will fail due to missing operations \n"
                + "\n".join(f"{i}" for i in missing_ops)
            )

-       self.optimization_profiles = (
+       self.optimization_profiles: Optional[List[trt.IOptimizationProfile]] = (
            [self.builder.create_optimization_profile()]
            if any(
                input_spec.shape_mode == Input._ShapeMode.DYNAMIC
@@ -86,8 +88,8 @@ def __init__(
        # Data types for TRT Module output Tensors
        self.output_dtypes = output_dtypes

-   def validate_conversion(self):
-       missing_converter = set()
+   def validate_conversion(self) -> Set[str]:
+       missing_converters = set()

        for node in self.module.graph.nodes:
            if node.op == "call_function" and not CONVERTERS.get(node):
@@ -98,25 +100,25 @@ def validate_conversion(self):
                submod = self.fetch_attr(node.target)
                submod_type = getattr(submod, "_base_class_origin", type(submod))
                if not CONVERTERS.get(node):
-                   missing_converter.add(f"{node.op} {torch.typename(submod_type)}")
+                   missing_converters.add(f"{node.op} {torch.typename(submod_type)}")  # type: ignore[no-untyped-call]

-       return missing_converter
+       return missing_converters

    def run(
        self,
-       workspace_size=0,
-       precision=torch.float32,
-       sparse_weights=False,
-       disable_tf32=False,
-       force_fp32_output=False,
-       strict_type_constraints=False,
-       algorithm_selector=None,
-       timing_cache=None,
-       profiling_verbosity=None,
-       tactic_sources=None,
-       max_aux_streams=None,
-       version_compatible=False,
-       optimization_level=None,
+       workspace_size: int = 0,
+       precision: torch.dtype = torch.float32,  # TODO: @peri044 Needs to be expanded to set
+       sparse_weights: bool = False,
+       disable_tf32: bool = False,
+       force_fp32_output: bool = False,
+       strict_type_constraints: bool = False,
+       algorithm_selector: Optional[trt.IAlgorithmSelector] = None,
+       timing_cache: Optional[trt.ITimingCache] = None,
+       profiling_verbosity: Optional[trt.ProfilingVerbosity] = None,
+       tactic_sources: Optional[int] = None,
+       max_aux_streams: Optional[int] = None,
+       version_compatible: bool = False,
+       optimization_level: Optional[int] = None,
    ) -> TRTInterpreterResult:
        """
        Build TensorRT engine with some configs.
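
For orientation, a hedged sketch of how the newly annotated run() might be invoked; `interpreter` stands in for an already-constructed TRTInterpreter, and the values are purely illustrative, not defaults recommended by this change:

import torch

# Hypothetical call site using only keywords from the signature above.
result = interpreter.run(
    workspace_size=1 << 30,      # 1 GiB of builder workspace
    precision=torch.float16,     # still a torch.dtype per the new annotation
    sparse_weights=False,
    version_compatible=False,
    optimization_level=3,
)
# `result` is a TRTInterpreterResult; its first field is the built engine,
# judging by the constructor call later in this diff.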
@@ -204,7 +206,7 @@ def run(
        if strict_type_constraints:
            builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES)

-       if self.optimization_profiles:
+       if len(self.optimization_profiles) > 0:
            for optimization_profile in self.optimization_profiles:
                builder_config.add_optimization_profile(optimization_profile)

@@ -232,7 +234,7 @@ def run(
            engine, self._input_names, self._output_names, serialized_cache
        )

-   def run_node(self, n):
+   def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
        self._cur_node_name = str(n)
        self._cur_node = n
        # add "_itensor_to_tensor_meta"
@@ -241,29 +243,31 @@ def run_node(self, n):
        n.kwargs = kwargs

        # run the node
-       trt_node = super().run_node(n)
+       trt_node: torch.fx.Node = super().run_node(n)

        # remove "_itensor_to_tensor_meta"
        kwargs = dict(n.kwargs)
        del kwargs["_itensor_to_tensor_meta"]
        n.kwargs = kwargs

        if isinstance(trt_node, trt.tensorrt.ITensor):
-           self._itensor_to_tensor_meta[trt_node] = n.meta.get("tensor_meta")
+           self._itensor_to_tensor_meta[trt_node] = n.meta.get("tensor_meta")  # type: ignore[assignment]

        return trt_node

-   def placeholder(self, target, args, kwargs):
+   def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor:
        self._input_names.append(target)
        current_input = self.input_specs[self.input_specs_iter]
        self.input_specs_iter += 1
        # Set optimization profile for dynamic input shape
-       shape = current_input.shape
+       shape = None
        if current_input.shape_mode == Input._ShapeMode.DYNAMIC:
+           assert isinstance(current_input.shape, dict)
            shape = []
            min_shape = current_input.shape["min_shape"]
            opt_shape = current_input.shape["opt_shape"]
            max_shape = current_input.shape["max_shape"]
+           # TODO: Does not support disjoint optimization profiles?
            self.optimization_profiles[0].set_shape(
                target, min_shape, opt_shape, max_shape
            )
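
The dynamic-shape branch above expects current_input.shape to be a dict with "min_shape"/"opt_shape"/"max_shape" keys, which feed the single optimization profile. A hedged sketch of such a spec; the Input constructor keywords shown are an assumption, only the min/opt/max triple is taken from the code above:

import torch

# Hypothetical dynamic-shape input spec; keyword names are assumed, the
# min/opt/max shapes mirror the dict keys read in the branch above.
dynamic_input = Input(
    min_shape=(1, 3, 224, 224),
    opt_shape=(8, 3, 224, 224),
    max_shape=(32, 3, 224, 224),
    dtype=torch.float32,
)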
@@ -274,14 +278,20 @@ def placeholder(self, target, args, kwargs):
                else:
                    # -1 to represent the dynamic dimension
                    shape.append(-1)
+       elif current_input.shape_mode == Input._ShapeMode.STATIC:
+           assert isinstance(current_input.shape, tuple)
+           shape = list(current_input.shape)
+       else:
+           raise RuntimeError(f"Unable to access shape spec for input: {target} (got: {current_input})")
+

        return self.network.add_input(
            name=target,
            shape=tuple(shape),
            dtype=unified_dtype_converter(current_input.torch_dtype, Frameworks.TRT),
        )

-   def call_module(self, target, args, kwargs):
+   def call_module(self, target: str, args: Any, kwargs: Any) -> Any:  # Probably should be Tuple[trt.ITensor]? Case for Any?
        assert isinstance(target, str)
        submod = self.fetch_attr(target)
        submod_type = getattr(submod, "_base_class_origin", type(submod))
@@ -295,17 +305,18 @@ def call_module(self, target, args, kwargs):
        assert self._cur_node_name is not None
        return converter(self.network, submod, args, kwargs, self._cur_node_name)

-   def call_function(self, target, args, kwargs):
+   def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
+       # TODO: Why is this stateful? We should be able to take in the inputs
        converter = CONVERTERS.get(self._cur_node)
        if not converter:
            raise RuntimeError(
-               f"Conversion of function {torch.typename(target)} not currently supported!"
+               f"Conversion of function {torch.typename(target)} not currently supported!"  # type: ignore[no-untyped-call]
            )

        assert self._cur_node_name is not None
        return converter(self.network, target, args, kwargs, self._cur_node_name)

-   def call_method(self, target, args, kwargs):
+   def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
        assert isinstance(target, str)
        converter = CONVERTERS.get(self._cur_node)
@@ -317,7 +328,7 @@ def call_method(self, target, args, kwargs):
        assert self._cur_node_name is not None
        return converter(self.network, target, args, kwargs, self._cur_node_name)

-   def output(self, target, args, kwargs):
+   def output(self, target: str, args: Any, kwargs: Any) -> None:
        assert len(args) == 1
        if isinstance(args[0], tuple):
            outputs = args[0]
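
Taken together, a minimal end-to-end sketch of the typed interface; `traced_module` and `input_specs` are placeholders, and the flow simply mirrors the annotated methods in this diff rather than any documented API:

import tensorrt as trt
import torch

# `traced_module` is assumed to be a torch.fx.GraphModule, `input_specs` a
# List[Input] matching its placeholder nodes.
interpreter = TRTInterpreter(
    traced_module,
    input_specs=input_specs,
    logger_level=trt.ILogger.Severity.INFO,
    output_dtypes=[torch.float32],
)
missing = interpreter.validate_conversion()  # Set[str] of unsupported ops
if not missing:
    result = interpreter.run(precision=torch.float16)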