1
- from typing import List , Dict , Any
1
+ from typing import List , Dict , Any , Tuple , Union , Set , Optional
2
2
import torch
3
3
from torch import nn
4
4
7
7
from torch_tensorrt .ts ._compile_spec import _parse_compile_spec , _parse_device
8
8
from torch_tensorrt ._Device import Device
9
9
from types import FunctionType
10
+ from torch_tensorrt ._Input import Input
10
11
11
12
12
13
def compile (
13
14
module : torch .jit .ScriptModule ,
14
- inputs = [],
15
- input_signature = None ,
16
- device = Device ._current_device (),
17
- disable_tf32 = False ,
18
- sparse_weights = False ,
19
- enabled_precisions = set (),
20
- refit = False ,
21
- debug = False ,
22
- capability = _enums .EngineCapability .default ,
23
- num_avg_timing_iters = 1 ,
24
- workspace_size = 0 ,
25
- dla_sram_size = 1048576 ,
26
- dla_local_dram_size = 1073741824 ,
27
- dla_global_dram_size = 536870912 ,
28
- calibrator = None ,
29
- truncate_long_and_double = False ,
30
- require_full_compilation = False ,
31
- min_block_size = 3 ,
32
- torch_executed_ops = [],
33
- torch_executed_modules = [],
34
- allow_shape_tensors = False ,
15
+ inputs : List [ Union [ Input , torch . Tensor ]] = [],
16
+ input_signature : Optional [ Tuple [ Union [ Input , torch . Tensor ]]] = None ,
17
+ device : Device = Device ._current_device (),
18
+ disable_tf32 : bool = False ,
19
+ sparse_weights : bool = False ,
20
+ enabled_precisions : Set [ Union [ torch . dtype , _enums . dtype ]] = set (), # type: ignore[name-defined]
21
+ refit : bool = False ,
22
+ debug : bool = False ,
23
+ capability : _enums . EngineCapability = _enums .EngineCapability .default , # type: ignore[name-defined]
24
+ num_avg_timing_iters : int = 1 ,
25
+ workspace_size : int = 0 ,
26
+ dla_sram_size : int = 1048576 ,
27
+ dla_local_dram_size : int = 1073741824 ,
28
+ dla_global_dram_size : int = 536870912 ,
29
+ calibrator : object = None ,
30
+ truncate_long_and_double : bool = False ,
31
+ require_full_compilation : bool = False ,
32
+ min_block_size : int = 3 ,
33
+ torch_executed_ops : List [ str ] = [],
34
+ torch_executed_modules : List [ str ] = [],
35
+ allow_shape_tensors : bool = False ,
35
36
) -> torch .jit .ScriptModule :
36
37
"""Compile a TorchScript module for NVIDIA GPUs using TensorRT
37
38
@@ -137,30 +138,30 @@ def compile(
137
138
}
138
139
139
140
compiled_cpp_mod = _C .compile_graph (module ._c , _parse_compile_spec (spec ))
140
- compiled_module = torch .jit ._recursive .wrap_cpp_module (compiled_cpp_mod )
141
+ compiled_module : torch . jit . ScriptModule = torch .jit ._recursive .wrap_cpp_module (compiled_cpp_mod ) # type: ignore[no-untyped-call]
141
142
return compiled_module
142
143
143
144
144
145
def convert_method_to_trt_engine (
145
146
module : torch .jit .ScriptModule ,
146
147
method_name : str = "forward" ,
147
- inputs = [],
148
- device = Device ._current_device (),
149
- disable_tf32 = False ,
150
- sparse_weights = False ,
151
- enabled_precisions = set (),
152
- refit = False ,
153
- debug = False ,
154
- capability = _enums .EngineCapability .default ,
155
- num_avg_timing_iters = 1 ,
156
- workspace_size = 0 ,
157
- dla_sram_size = 1048576 ,
158
- dla_local_dram_size = 1073741824 ,
159
- dla_global_dram_size = 536870912 ,
160
- truncate_long_and_double = False ,
161
- calibrator = None ,
162
- allow_shape_tensors = False ,
163
- ) -> bytearray :
148
+ inputs : List [ Union [ Input , torch . Tensor ]] = [],
149
+ device : Device = Device ._current_device (),
150
+ disable_tf32 : bool = False ,
151
+ sparse_weights : bool = False ,
152
+ enabled_precisions : Set [ Union [ torch . dtype , _enums . dtype ]] = set (), # type: ignore[name-defined]
153
+ refit : bool = False ,
154
+ debug : bool = False ,
155
+ capability : _enums . EngineCapability = _enums .EngineCapability .default , # type: ignore[name-defined]
156
+ num_avg_timing_iters : int = 1 ,
157
+ workspace_size : int = 0 ,
158
+ dla_sram_size : int = 1048576 ,
159
+ dla_local_dram_size : int = 1073741824 ,
160
+ dla_global_dram_size : int = 536870912 ,
161
+ truncate_long_and_double : bool = False ,
162
+ calibrator : object = None ,
163
+ allow_shape_tensors : bool = False ,
164
+ ) -> bytes :
164
165
"""Convert a TorchScript module method to a serialized TensorRT engine
165
166
166
167
Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings
@@ -221,7 +222,7 @@ def convert_method_to_trt_engine(
221
222
allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
222
223
223
224
Returns:
224
- bytearray : Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
225
+ bytes : Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
225
226
"""
226
227
if isinstance (module , torch .jit .ScriptFunction ):
227
228
raise TypeError (
@@ -293,8 +294,8 @@ def embed_engine_in_new_module(
293
294
input_binding_names ,
294
295
output_binding_names ,
295
296
)
296
- return torch .jit ._recursive .wrap_cpp_module (cpp_mod )
297
-
297
+ wrapped_mod : torch .jit .ScriptModule = torch . jit . _recursive .wrap_cpp_module (cpp_mod ) # type: ignore[no-untyped-call]
298
+ return wrapped_mod
298
299
299
300
def check_method_op_support (
300
301
module : torch .jit .ScriptModule , method_name : str = "forward"
@@ -312,4 +313,5 @@ def check_method_op_support(
312
313
Returns:
313
314
bool: True if supported Method
314
315
"""
315
- return _C .check_method_op_support (module ._c , method_name )
316
+ supported : bool = _C .check_method_op_support (module ._c , method_name )
317
+ return supported
0 commit comments