1
- from typing import List , Dict , Any
1
+ from typing import List , Dict , Any , Set , Union , Callable , TypeGuard
2
+
2
3
import torch_tensorrt .ts
3
4
4
- from torch_tensorrt import logging
5
+ from torch_tensorrt import logging , Input , dtype
5
6
import torch
6
7
import torch .fx
7
8
from enum import Enum
8
9
9
10
import torch_tensorrt .fx
11
+ from torch_tensorrt .fx import InputTensorSpec
10
12
from torch_tensorrt .fx .utils import LowerPrecision
11
13
12
14
15
def _non_fx_input_interface(
    inputs: List[Input | torch.Tensor | InputTensorSpec],
) -> TypeGuard[List[Input | torch.Tensor]]:
    """Narrow ``inputs`` to the TS/dynamo-compatible input types.

    Args:
        inputs: Candidate compile inputs, possibly mixing ``Input``,
            ``torch.Tensor``, and fx ``InputTensorSpec`` elements.

    Returns:
        True iff every element is a ``torch.Tensor`` or ``Input`` (i.e. no
        fx-only ``InputTensorSpec``), letting the type checker treat the
        list as valid for the TorchScript frontend.
    """
    # Generator with a tuple isinstance check: short-circuits on the first
    # non-matching element instead of building an intermediate list, and
    # avoids the X | Y isinstance form for broader compatibility.
    return all(isinstance(i, (torch.Tensor, Input)) for i in inputs)
17
+
18
def _fx_input_interface(
    inputs: List[Input | torch.Tensor | InputTensorSpec],
) -> TypeGuard[List[InputTensorSpec | torch.Tensor]]:
    """Narrow ``inputs`` to the fx-frontend-compatible input types.

    Args:
        inputs: Candidate compile inputs, possibly mixing ``Input``,
            ``torch.Tensor``, and fx ``InputTensorSpec`` elements.

    Returns:
        True iff every element is a ``torch.Tensor`` or ``InputTensorSpec``
        (i.e. no TS-only ``Input``), letting the type checker treat the
        list as valid for the fx frontend.
    """
    # Generator with a tuple isinstance check: short-circuits on the first
    # non-matching element instead of building an intermediate list, and
    # avoids the X | Y isinstance form for broader compatibility.
    return all(isinstance(i, (torch.Tensor, InputTensorSpec)) for i in inputs)
20
+
13
21
class _IRType (Enum ):
14
22
"""Enum to set the minimum required logging level to print a message to stdout"""
15
23
@@ -80,11 +88,11 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
80
88
81
89
def compile (
82
90
module : Any ,
83
- ir = "default" ,
84
- inputs = [],
85
- enabled_precisions = set ([torch .float ]),
86
- ** kwargs ,
87
- ):
91
+ ir : str = "default" ,
92
+ inputs : List [ Union [ Input , torch . Tensor , InputTensorSpec ]] = [],
93
+ enabled_precisions : Set [ Union [ torch . dtype , dtype ]] = set ([torch .float ]),
94
+ ** kwargs : Any ,
95
+ ) -> Union [ torch . nn . Module , torch . jit . ScriptModule , torch . fx . GraphModule , Callable [[ Any ], Any ]] :
88
96
"""Compile a PyTorch module for NVIDIA GPUs using TensorRT
89
97
90
98
Takes a existing PyTorch module and a set of settings to configure the compiler
@@ -130,9 +138,11 @@ def compile(
130
138
"Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript" ,
131
139
)
132
140
ts_mod = torch .jit .script (module )
133
- return torch_tensorrt .ts .compile (
141
+ assert _non_fx_input_interface (inputs )
142
+ compiled_ts_module : torch .jit .ScriptModule = torch_tensorrt .ts .compile (
134
143
ts_mod , inputs = inputs , enabled_precisions = enabled_precisions , ** kwargs
135
144
)
145
+ return compiled_ts_module
136
146
elif target_ir == _IRType .fx :
137
147
if (
138
148
torch .float16 in enabled_precisions
@@ -147,38 +157,31 @@ def compile(
147
157
else :
148
158
raise ValueError (f"Precision { enabled_precisions } not supported on FX" )
149
159
150
- return torch_tensorrt .fx .compile (
160
+ assert _fx_input_interface (inputs )
161
+ compiled_fx_module : torch .nn .Module = torch_tensorrt .fx .compile (
151
162
module ,
152
163
inputs ,
153
164
lower_precision = lower_precision ,
154
- max_batch_size = inputs [0 ].size (0 ),
155
165
explicit_batch_dimension = True ,
156
166
dynamic_batch = False ,
157
167
** kwargs ,
158
168
)
169
+ return compiled_fx_module
159
170
elif target_ir == _IRType .dynamo :
160
- from torch_tensorrt import Device
161
- from torch_tensorrt .dynamo .utils import prepare_inputs , prepare_device
162
- import collections .abc
163
-
164
- if not isinstance (inputs , collections .abc .Sequence ):
165
- inputs = [inputs ]
166
- device = kwargs .get ("device" , Device ._current_device ())
167
- torchtrt_inputs , torch_inputs = prepare_inputs (inputs , prepare_device (device ))
168
- module = torch_tensorrt .dynamo .trace (module , torch_inputs , ** kwargs )
169
171
return torch_tensorrt .dynamo .compile (
170
172
module ,
171
173
inputs = inputs ,
172
174
enabled_precisions = enabled_precisions ,
173
175
** kwargs ,
174
176
)
177
+ return compiled_aten_module
175
178
elif target_ir == _IRType .torch_compile :
176
179
return torch_compile (module , enabled_precisions = enabled_precisions , ** kwargs )
177
180
else :
178
181
raise RuntimeError ("Module is an unknown format or the ir requested is unknown" )
179
182
180
183
181
- def torch_compile (module , ** kwargs ) :
184
+ def torch_compile (module : torch . nn . Module , ** kwargs : Any ) -> Callable [..., Any ] :
182
185
"""
183
186
Returns a boxed model which is the output of torch.compile.
184
187
This does not compile the model to TRT. Execute this model on
@@ -194,11 +197,11 @@ def torch_compile(module, **kwargs):
194
197
def convert_method_to_trt_engine (
195
198
module : Any ,
196
199
method_name : str ,
197
- ir = "default" ,
198
- inputs = [],
199
- enabled_precisions = set ([torch .float ]),
200
- ** kwargs ,
201
- ):
200
+ ir : str = "default" ,
201
+ inputs : List [ Union [ Input , torch . Tensor ]] = [],
202
+ enabled_precisions : Set [ Union [ torch . dtype , dtype ]] = set ([torch .float ]),
203
+ ** kwargs : Any ,
204
+ ) -> bytes :
202
205
"""Convert a TorchScript module method to a serialized TensorRT engine
203
206
204
207
Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings
0 commit comments