2
2
import warnings
3
3
from datetime import datetime
4
4
from packaging import version
5
- from typing import Any , Callable , Dict , List , NamedTuple , Optional , Sequence
5
+ from typing import Any , Callable , Dict , List , NamedTuple , Optional , Sequence , Set
6
6
7
7
import numpy
8
8
9
9
# @manual=//deeplearning/trt/python:py_tensorrt
10
10
import tensorrt as trt
11
11
import torch
12
12
import torch .fx
13
+ from torch .fx .node import Target
13
14
from torch ._ops import OpOverload
14
15
from torch .fx .node import _get_qualified_name
15
16
from torch .fx .passes .shape_prop import TensorMetadata
@@ -42,12 +43,13 @@ def __init__(
42
43
self ,
43
44
module : torch .fx .GraphModule ,
44
45
input_specs : List [Input ],
45
- logger_level = None ,
46
- output_dtypes = None ,
46
+ logger_level : trt . ILogger . Severity = trt . ILogger . Severity . WARNING ,
47
+ output_dtypes : Optional [ List [ torch . dtype ]] = None ,
47
48
):
48
49
super ().__init__ (module )
49
50
50
- self .logger = trt .Logger (logger_level or trt .Logger .WARNING )
51
+ # TODO: @narendasan replace with Torch-TensorRT Logger
52
+ self .logger = trt .Logger (logger_level )
51
53
self .builder = trt .Builder (self .logger )
52
54
53
55
flag = 0
@@ -60,12 +62,13 @@ def __init__(
60
62
61
63
missing_ops = self .validate_conversion ()
62
64
if missing_ops :
65
+ # TODO: @narendasan make sure to set logging.captureWarnings(True)
63
66
warnings .warn (
64
67
"Interpretation will fail due to missing operations \n "
65
68
+ "\n " .join (f"{ i } " for i in missing_ops )
66
69
)
67
70
68
- self .optimization_profiles : Optional [ List ] = None
71
+ self .optimization_profiles : List [ trt . IOptimizationProfile ] = []
69
72
self .input_specs = input_specs
70
73
self .input_specs_iter = 0
71
74
self ._cur_node_name : Optional [str ] = None
@@ -78,37 +81,37 @@ def __init__(
78
81
# Data types for TRT Module output Tensors
79
82
self .output_dtypes = output_dtypes
80
83
81
def validate_conversion(self) -> Set[str]:
    """Collect a description of every graph node that has no converter.

    Walks ``self.module.graph.nodes`` and looks each callable node up in
    ``CONVERTERS``:

    * ``call_function`` and ``call_method`` nodes are looked up by
      ``node.target``.
    * ``call_module`` nodes are looked up by the submodule's
      ``_base_class_origin`` attribute, falling back to its concrete type.

    Nodes with any other ``op`` are ignored.

    Returns:
        A set of strings of the form ``"<op> <qualified target name>"``, one
        per distinct unsupported operation. Empty when every node has a
        converter.
    """
    unsupported: Set[str] = set()

    for node in self.module.graph.nodes:
        op_kind = node.op

        if op_kind == "call_function":
            if not CONVERTERS.get(node.target):
                unsupported.add(f"{node.op} {_get_qualified_name(node.target)}")
        elif op_kind == "call_method":
            if not CONVERTERS.get(node.target):
                unsupported.add(f"{node.op} torch.Tensor.{node.target}")
        elif op_kind == "call_module":
            # Modules are keyed by their originating base class when one is
            # recorded on the instance, otherwise by the instance's own type.
            submodule = self.fetch_attr(node.target)
            module_cls = getattr(submodule, "_base_class_origin", type(submodule))
            if not CONVERTERS.get(module_cls):
                unsupported.add(f"{node.op} {torch.typename(module_cls)}")  # type: ignore[no-untyped-call]

    return unsupported
96
99
97
100
def run (
98
101
self ,
99
- workspace_size = 0 ,
100
- precision = torch .float32 ,
101
- sparse_weights = False ,
102
- disable_tf32 = False ,
103
- force_fp32_output = False ,
104
- strict_type_constraints = False ,
105
- algorithm_selector = None ,
106
- timing_cache = None ,
107
- profiling_verbosity = None ,
108
- tactic_sources = None ,
109
- max_aux_streams = None ,
110
- version_compatible = False ,
111
- optimization_level = None ,
102
+ workspace_size : int = 0 ,
103
+ precision : torch . dtype = torch .float32 , # TODO: @peri044 Needs to be expanded to set
104
+ sparse_weights : bool = False ,
105
+ disable_tf32 : bool = False ,
106
+ force_fp32_output : bool = False ,
107
+ strict_type_constraints : bool = False ,
108
+ algorithm_selector : Optional [ trt . IAlgorithmSelector ] = None ,
109
+ timing_cache : Optional [ trt . ITimingCache ] = None ,
110
+ profiling_verbosity : Optional [ trt . ProfilingVerbosity ] = None ,
111
+ tactic_sources : Optional [ int ] = None ,
112
+ max_aux_streams : Optional [ int ] = None ,
113
+ version_compatible : bool = False ,
114
+ optimization_level : Optional [ int ] = None ,
112
115
) -> TRTInterpreterResult :
113
116
"""
114
117
Build TensorRT engine with some configs.
@@ -196,7 +199,7 @@ def run(
196
199
if strict_type_constraints :
197
200
builder_config .set_flag (trt .BuilderFlag .STRICT_TYPES )
198
201
199
- if self .optimization_profiles :
202
+ if len ( self .optimization_profiles ) > 0 :
200
203
for optimization_profile in self .optimization_profiles :
201
204
builder_config .add_optimization_profile (optimization_profile )
202
205
@@ -224,55 +227,63 @@ def run(
224
227
engine , self ._input_names , self ._output_names , serialized_cache
225
228
)
226
229
227
def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
    """Run a single FX node, exposing the ITensor metadata map to its converter.

    The node's kwargs are temporarily extended with an
    ``"_itensor_to_tensor_meta"`` entry pointing at
    ``self._itensor_to_tensor_meta`` while the parent interpreter runs the
    node. The entry is removed again afterwards, including when running the
    node raises, so the graph node is never left carrying the extra kwarg.

    Args:
        n: The FX node to run.

    Returns:
        Whatever the parent ``run_node`` produced for this node. When that
        result is a TensorRT ``ITensor``, the node's ``"tensor_meta"`` entry
        (or ``None`` if absent) is recorded for it in
        ``self._itensor_to_tensor_meta``.
    """
    self._cur_node_name = str(n)

    # add "_itensor_to_tensor_meta"
    kwargs = dict(n.kwargs)
    kwargs["_itensor_to_tensor_meta"] = self._itensor_to_tensor_meta
    n.kwargs = kwargs

    try:
        # run the node
        trt_node: torch.fx.Node = super().run_node(n)
    finally:
        # remove "_itensor_to_tensor_meta" even if the node failed, so the
        # node's kwargs are restored before the exception propagates
        kwargs = dict(n.kwargs)
        kwargs.pop("_itensor_to_tensor_meta", None)
        n.kwargs = kwargs

    if isinstance(trt_node, trt.tensorrt.ITensor):
        self._itensor_to_tensor_meta[trt_node] = n.meta.get("tensor_meta")  # type: ignore[assignment]

    return trt_node
246
249
247
def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor:
    """Add the next graph input to the TensorRT network.

    Consumes the next entry of ``self.input_specs`` (tracked by
    ``self.input_specs_iter``) and records ``target`` in
    ``self._input_names``.

    For a dynamic input, the min/opt/max shapes are registered on the first
    optimization profile, which is created through the builder on first use.
    Every dimension whose min, opt and max values differ is declared to the
    network as ``-1``. For a static input the shape is used as given.

    Args:
        target: Name of the placeholder; used as the network input name.
        args: Unused.
        kwargs: Unused.

    Returns:
        The value returned by ``self.network.add_input`` for the new input.

    Raises:
        RuntimeError: If the input spec's shape mode is neither dynamic nor
            static.
    """
    self._input_names.append(target)
    current_input = self.input_specs[self.input_specs_iter]
    self.input_specs_iter += 1

    # Set optimization profile for dynamic input shape
    shape = None
    if current_input.shape_mode == Input._ShapeMode.DYNAMIC:
        assert isinstance(current_input.shape, dict)
        min_shape = current_input.shape["min_shape"]
        opt_shape = current_input.shape["opt_shape"]
        max_shape = current_input.shape["max_shape"]
        assert len(min_shape) == len(opt_shape) == len(max_shape)

        # TODO: Does not support disjoint optimization profiles?
        # Profiles must come from the builder, and set_shape() returns None,
        # so the profile object itself (not the call result) is what gets
        # stored. All dynamic inputs share the first profile.
        if not self.optimization_profiles:
            self.optimization_profiles.append(
                self.builder.create_optimization_profile()
            )
        self.optimization_profiles[0].set_shape(
            target, min_shape, opt_shape, max_shape
        )

        # -1 to represent the dynamic dimension
        shape = [
            lo if lo == opt == hi else -1
            for lo, opt, hi in zip(min_shape, opt_shape, max_shape)
        ]
    elif current_input.shape_mode == Input._ShapeMode.STATIC:
        assert isinstance(current_input.shape, tuple)
        shape = list(current_input.shape)
    else:
        raise RuntimeError(
            f"Unable to access shape spec for input: {target} (got: {current_input})"
        )

    return self.network.add_input(
        name=target,
        shape=tuple(shape),
        dtype=unified_dtype_converter(current_input.torch_dtype, Frameworks.TRT),
    )
274
285
275
- def call_module (self , target , args , kwargs ):
286
+ def call_module (self , target : str , args : Any , kwargs : Any ) -> Any : #Probably should be Tuple[trt.ITensor]? Case for Any?
276
287
assert isinstance (target , str )
277
288
submod = self .fetch_attr (target )
278
289
submod_type = getattr (submod , "_base_class_origin" , type (submod ))
@@ -286,17 +297,17 @@ def call_module(self, target, args, kwargs):
286
297
assert self ._cur_node_name is not None
287
298
return converter (self .network , submod , args , kwargs , self ._cur_node_name )
288
299
289
def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
    """Convert a ``call_function`` node using its registered converter.

    Args:
        target: The function being called; used as the ``CONVERTERS`` key.
        args: Positional arguments forwarded to the converter.
        kwargs: Keyword arguments forwarded to the converter.

    Returns:
        Whatever the converter returns for this node.

    Raises:
        RuntimeError: If ``CONVERTERS`` has no converter for ``target``.
    """
    convert_fn = CONVERTERS.get(target)

    if convert_fn:
        # The current node name is set by run_node before dispatch.
        assert self._cur_node_name is not None
        node_name = self._cur_node_name
        return convert_fn(self.network, target, args, kwargs, node_name)

    raise RuntimeError(
        f"Conversion of function {torch.typename(target)} not currently supported!"  # type: ignore[no-untyped-call]
    )
298
309
299
- def call_method (self , target , args , kwargs ) :
310
+ def call_method (self , target : str , args : Any , kwargs : Any ) -> Any :
300
311
assert isinstance (target , str )
301
312
converter = CONVERTERS .get (target )
302
313
@@ -308,7 +319,7 @@ def call_method(self, target, args, kwargs):
308
319
assert self ._cur_node_name is not None
309
320
return converter (self .network , target , args , kwargs , self ._cur_node_name )
310
321
311
- def output (self , target , args , kwargs ) :
322
+ def output (self , target : str , args : Any , kwargs : Any ) -> None :
312
323
assert len (args ) == 1
313
324
if isinstance (args [0 ], tuple ):
314
325
outputs = args [0 ]
0 commit comments