1
- from typing import List , Dict , Any , Set
1
+ from typing import List , Dict , Any , Set , Union , Optional
2
2
import torch
3
3
from torch_tensorrt import _C
4
4
import torch_tensorrt ._C .ts as _ts_C
13
13
import tensorrt as trt
14
14
15
15
16
- def _internal_input_to_torch_class_input (i : _C .Input ) -> torch .classes .tensorrt ._Input :
16
+ def _internal_input_to_torch_class_input (i : _C .Input ) -> torch .classes .tensorrt ._Input : # type: ignore[name-defined]
17
17
clone = torch .classes .tensorrt ._Input ()
18
18
clone ._set_min (i .min )
19
19
clone ._set_opt (i .opt )
@@ -40,7 +40,7 @@ def _supported_input_size_type(input_size: Any) -> bool:
40
40
)
41
41
42
42
43
- def _parse_op_precision (precision : Any ) -> _enums .dtype :
43
+ def _parse_op_precision (precision : Any ) -> _enums .dtype : # type: ignore[name-defined]
44
44
if isinstance (precision , torch .dtype ):
45
45
if precision == torch .int8 :
46
46
return _enums .dtype .int8
@@ -64,7 +64,7 @@ def _parse_op_precision(precision: Any) -> _enums.dtype:
64
64
)
65
65
66
66
67
- def _parse_enabled_precisions (precisions : Any ) -> Set :
67
+ def _parse_enabled_precisions (precisions : Any ) -> Set [ _enums . dtype ]: # type: ignore[name-defined]
68
68
parsed_precisions = set ()
69
69
if any ([isinstance (precisions , type ) for type in [list , tuple , set ]]):
70
70
for p in precisions :
@@ -74,7 +74,7 @@ def _parse_enabled_precisions(precisions: Any) -> Set:
74
74
return parsed_precisions
75
75
76
76
77
- def _parse_device_type (device : Any ) -> _enums .DeviceType :
77
+ def _parse_device_type (device : Any ) -> _enums .DeviceType : # type: ignore[name-defined]
78
78
if isinstance (device , torch .device ):
79
79
if device .type == "cuda" :
80
80
return _C .DeviceType .gpu
@@ -159,7 +159,7 @@ def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> _ts_C.TorchFallback:
159
159
return info
160
160
161
161
162
- def _parse_input_signature (input_signature : Any , depth : int = 0 ):
162
+ def _parse_input_signature (input_signature : Any , depth : int = 0 ) -> Any :
163
163
if depth > 2 :
164
164
raise AssertionError (
165
165
"Input nesting depth exceeds max supported depth, use 1 level: [A, B], or 2 level: [A, (B, C)]"
@@ -197,13 +197,16 @@ def _parse_input_signature(input_signature: Any, depth: int = 0):
197
197
if i .shape_mode == Input ._ShapeMode .STATIC :
198
198
ts_i = TorchScriptInput (shape = i .shape , dtype = i .dtype , format = i .format )
199
199
elif i .shape_mode == Input ._ShapeMode .DYNAMIC :
200
- ts_i = TorchScriptInput (
201
- min_shape = i .shape ["min_shape" ],
202
- opt_shape = i .shape ["opt_shape" ],
203
- max_shape = i .shape ["max_shape" ],
204
- dtype = i .dtype ,
205
- format = i .format ,
206
- )
200
+ if isinstance (i .shape , dict ):
201
+ ts_i = TorchScriptInput (
202
+ min_shape = i .shape ["min_shape" ],
203
+ opt_shape = i .shape ["opt_shape" ],
204
+ max_shape = i .shape ["max_shape" ],
205
+ dtype = i .dtype ,
206
+ format = i .format ,
207
+ )
208
+ else :
209
+ raise ValueError (f"Input set as dynamic, expected dictionary of shapes but found { i .shape } " )
207
210
else :
208
211
raise ValueError (
209
212
"Invalid shape mode detected for input while parsing the input_signature"
@@ -342,24 +345,24 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec:
342
345
343
346
344
347
def TensorRTCompileSpec (
345
- inputs = [],
346
- input_signature = None ,
347
- device = Device ._current_device (),
348
- disable_tf32 = False ,
349
- sparse_weights = False ,
350
- enabled_precisions = set (),
351
- refit = False ,
352
- debug = False ,
353
- capability = _enums .EngineCapability .default ,
354
- num_avg_timing_iters = 1 ,
355
- workspace_size = 0 ,
356
- dla_sram_size = 1048576 ,
357
- dla_local_dram_size = 1073741824 ,
358
- dla_global_dram_size = 536870912 ,
359
- truncate_long_and_double = False ,
360
- calibrator = None ,
361
- allow_shape_tensors = False ,
362
- ) -> torch .classes .tensorrt .CompileSpec :
348
+ inputs : List [ Union [ torch . Tensor , Input ]] = [],
349
+ input_signature : Optional [ Any ] = None ,
350
+ device : Union [ torch . Device , Device ] = Device ._current_device (),
351
+ disable_tf32 : bool = False ,
352
+ sparse_weights : bool = False ,
353
+ enabled_precisions : Set [ Union [ torch . dtype , _enums . dtype ]] = set (), # type: ignore[name-defined]
354
+ refit : bool = False ,
355
+ debug : bool = False ,
356
+ capability : _enums . EngineCapability = _enums .EngineCapability .default , # type: ignore[name-defined]
357
+ num_avg_timing_iters : int = 1 ,
358
+ workspace_size : int = 0 ,
359
+ dla_sram_size : int = 1048576 ,
360
+ dla_local_dram_size : int = 1073741824 ,
361
+ dla_global_dram_size : int = 536870912 ,
362
+ truncate_long_and_double : bool = False ,
363
+ calibrator : object = None ,
364
+ allow_shape_tensors : bool = False ,
365
+ ) -> torch .classes .tensorrt .CompileSpec : # type: ignore[name-defined]
363
366
"""Utility to create a formated spec dictionary for using the PyTorch TensorRT backend
364
367
365
368
Keyword Args:
0 commit comments