@@ -73,7 +73,7 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
73
73
)
74
74
return _IRType .ts
75
75
else :
76
- raise ValueError ("Module was provided with in an unsupported format" )
76
+ raise ValueError ("Module was provided in an unsupported format" )
77
77
else :
78
78
raise ValueError ("Unknown ir was requested" )
79
79
@@ -157,37 +157,38 @@ def compile(
157
157
** kwargs ,
158
158
)
159
159
elif target_ir == _IRType .dynamo :
160
+ from torch_tensorrt import Device
161
+ from torch_tensorrt .dynamo .utils import prepare_inputs , prepare_device
162
+ import collections .abc
163
+
164
+ if not isinstance (inputs , collections .abc .Sequence ):
165
+ inputs = [inputs ]
166
+ device = kwargs .get ("device" , Device ._current_device ())
167
+ torchtrt_inputs , torch_inputs = prepare_inputs (inputs , prepare_device (device ))
168
+ module = torch_tensorrt .dynamo .trace (module , torch_inputs , ** kwargs )
160
169
return torch_tensorrt .dynamo .compile (
161
170
module ,
162
171
inputs = inputs ,
163
172
enabled_precisions = enabled_precisions ,
164
173
** kwargs ,
165
174
)
166
175
elif target_ir == _IRType .torch_compile :
167
- return torch_compile (
168
- module , inputs , enabled_precisions = enabled_precisions , ** kwargs
169
- )
176
+ return torch_compile (module , enabled_precisions = enabled_precisions , ** kwargs )
170
177
else :
171
178
raise RuntimeError ("Module is an unknown format or the ir requested is unknown" )
172
179
173
180
174
- def torch_compile (module , inputs , ** kwargs ):
175
-
176
- from torch_tensorrt .dynamo .utils import prepare_inputs , prepare_device
181
+ def torch_compile (module , ** kwargs ):
182
+ """
183
+ Returns a boxed model which is the output of torch.compile.
184
+ This does not compile the model to TRT. Execute this model on
185
+ sample inputs to compile the model to TRT.
186
+ """
177
187
from torch_tensorrt .dynamo .backend import torch_tensorrt_backend
178
- from torch_tensorrt import Device
179
- import collections .abc
180
-
181
- if not isinstance (inputs , collections .abc .Sequence ):
182
- inputs = [inputs ]
183
188
184
- device = kwargs .get ("device" , Device ._current_device ())
185
- torchtrt_inputs , torch_inputs = prepare_inputs (inputs , prepare_device (device ))
186
- model = torch .compile (module , backend = torch_tensorrt_backend , options = {** kwargs })
187
- # Ensure compilation occurs by calling the function with provided inputs
188
- model (* torch_inputs )
189
+ boxed_fn = torch .compile (module , backend = torch_tensorrt_backend , options = {** kwargs })
189
190
190
- return model
191
+ return boxed_fn
191
192
192
193
193
194
def convert_method_to_trt_engine (
@@ -246,6 +247,16 @@ def convert_method_to_trt_engine(
246
247
** kwargs ,
247
248
)
248
249
elif target_ir == _IRType .fx :
249
- raise RuntimeError ("fx is currently not supported" )
250
+ raise RuntimeError (
251
+ "convert_method_to_trt_engine call is not supported for ir=fx"
252
+ )
253
+ elif target_ir == _IRType .dynamo :
254
+ raise RuntimeError (
255
+ "convert_method_to_trt_engine call is not supported for ir=dynamo."
256
+ )
257
+ elif target_ir == _IRType .torch_compile :
258
+ raise RuntimeError (
259
+ "convert_method_to_trt_engine call is not supported for ir=torch_compile"
260
+ )
250
261
else :
251
262
raise RuntimeError ("Module is an unknown format or the ir requested is unknown" )
0 commit comments