@@ -148,6 +148,8 @@ def compile(
148
148
module : Any ,
149
149
ir : str = "default" ,
150
150
inputs : Optional [Sequence [Input | torch .Tensor | InputTensorSpec ]] = None ,
151
+ arg_inputs : Optional [Sequence [Sequence [Any ]]] = None ,
152
+ kwarg_inputs : Optional [dict [Any , Any ]] = None ,
151
153
enabled_precisions : Optional [Set [torch .dtype | dtype ]] = None ,
152
154
** kwargs : Any ,
153
155
) -> (
@@ -180,14 +182,16 @@ def compile(
180
182
), # Dynamic input shape for input #2
181
183
torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
182
184
]
183
-
185
+ arg_inputs (Tuple[Any, ...]): Same as inputs; an alias that reads more clearly alongside kwarg_inputs. Pass either inputs or arg_inputs, not both.
186
+ kwarg_inputs (dict[Any, Any]): Optional keyword inputs to the module's forward function.
184
187
enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
185
188
ir (str): The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path)
186
189
**kwargs: Additional settings for the specific requested strategy (See submodules for more info)
187
190
188
191
Returns:
189
192
torch.nn.Module: Compiled Module, when run it will execute via TensorRT
190
193
"""
194
+
191
195
input_list = inputs if inputs is not None else []
192
196
enabled_precisions_set : Set [dtype | torch .dtype ] = (
193
197
enabled_precisions
@@ -238,17 +242,33 @@ def compile(
238
242
return compiled_fx_module
239
243
elif target_ir == _IRType .dynamo :
240
244
# Prepare torch and torchtrt inputs
245
+ if not arg_inputs and not inputs :
246
+ raise AssertionError ("'arg_input' or 'input' should not be None." )
247
+
248
+ elif arg_inputs and inputs :
249
+ raise AssertionError (
250
+ "'arg_input' and 'input' should not be used at the same time."
251
+ )
252
+ arg_inputs = inputs or arg_inputs
253
+
254
+ if kwarg_inputs is None :
255
+ kwarg_inputs = {}
256
+
241
257
from torch_tensorrt .dynamo .utils import prepare_inputs
242
258
243
- if not isinstance (input_list , collections .abc .Sequence ):
244
- input_list = [input_list ]
259
+ if not isinstance (arg_inputs , collections .abc .Sequence ):
260
+ arg_inputs = [arg_inputs ] # type: ignore
245
261
246
262
# Export the module
247
- torchtrt_inputs = prepare_inputs (input_list )
248
- exp_program = dynamo_trace (module , torchtrt_inputs , ** kwargs )
263
+ torchtrt_arg_inputs = prepare_inputs (arg_inputs )
264
+ torchtrt_kwarg_inputs = prepare_inputs (kwarg_inputs )
265
+
266
+ exp_program = dynamo_trace (
267
+ module , torchtrt_arg_inputs , kwarg_inputs = torchtrt_kwarg_inputs , ** kwargs
268
+ )
249
269
trt_graph_module = dynamo_compile (
250
270
exp_program ,
251
- inputs = torchtrt_inputs ,
271
+ arg_inputs = torchtrt_arg_inputs ,
252
272
enabled_precisions = enabled_precisions_set ,
253
273
** kwargs ,
254
274
)
@@ -280,7 +300,9 @@ def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Any:
280
300
def convert_method_to_trt_engine (
281
301
module : Any ,
282
302
method_name : str = "forward" ,
283
- inputs : Optional [Sequence [Input | torch .Tensor ]] = None ,
303
+ inputs : Optional [Sequence [Input | torch .Tensor | InputTensorSpec ]] = None ,
304
+ arg_inputs : Optional [Sequence [Sequence [Any ]]] = None ,
305
+ kwarg_inputs : Optional [dict [Any , Any ]] = None ,
284
306
ir : str = "default" ,
285
307
enabled_precisions : Optional [Set [torch .dtype | dtype ]] = None ,
286
308
** kwargs : Any ,
@@ -309,6 +331,8 @@ def convert_method_to_trt_engine(
309
331
torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
310
332
]
311
333
334
+ arg_inputs (Tuple[Any, ...]): Same as inputs; an alias that reads more clearly alongside kwarg_inputs. Pass either inputs or arg_inputs, not both.
335
+ kwarg_inputs (dict[Any, Any]): Optional keyword inputs to the module's forward function.
312
336
enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
313
337
ir (str): The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path)
314
338
**kwargs: Additional settings for the specific requested strategy (See submodules for more info)
@@ -330,7 +354,7 @@ def convert_method_to_trt_engine(
330
354
ts_mod = torch .jit .script (module )
331
355
serialized_engine : bytes = ts_convert_method_to_trt_engine (
332
356
ts_mod ,
333
- inputs = inputs ,
357
+ inputs = arg_inputs ,
334
358
method_name = method_name ,
335
359
enabled_precisions = enabled_precisions_set ,
336
360
** kwargs ,
@@ -342,18 +366,35 @@ def convert_method_to_trt_engine(
342
366
)
343
367
elif target_ir == _IRType .dynamo :
344
368
# Prepare torch and torchtrt inputs
369
+ if not arg_inputs and not inputs :
370
+ raise AssertionError ("'arg_input' or 'input' should not be None." )
371
+
372
+ elif arg_inputs and inputs :
373
+ raise AssertionError (
374
+ "'arg_input' and 'input' should not be used at the same time."
375
+ )
376
+ arg_inputs = arg_inputs or inputs
377
+
378
+ if kwarg_inputs is None :
379
+ kwarg_inputs = {}
380
+
345
381
from torch_tensorrt .dynamo .utils import prepare_inputs
346
382
347
- if not isinstance (inputs , collections .abc .Sequence ):
348
- inputs = [inputs ]
383
+ if not isinstance (arg_inputs , collections .abc .Sequence ):
384
+ arg_inputs = [arg_inputs ] # type: ignore
349
385
350
386
# Export the module
351
- torchtrt_inputs = prepare_inputs (inputs )
352
- exp_program = torch_tensorrt .dynamo .trace (module , torchtrt_inputs , ** kwargs )
387
+ torchtrt_arg_inputs = prepare_inputs (arg_inputs )
388
+ torchtrt_kwarg_inputs = prepare_inputs (kwarg_inputs )
389
+
390
+ exp_program = torch_tensorrt .dynamo .trace (
391
+ module , torchtrt_arg_inputs , kwarg_inputs = torchtrt_kwarg_inputs ** kwargs
392
+ )
353
393
354
394
return dynamo_convert_module_to_trt_engine (
355
395
exp_program ,
356
- inputs = tuple (inputs ),
396
+ arg_inputs = tuple (arg_inputs ),
397
+ kwarg_inputs = torchtrt_kwarg_inputs ,
357
398
enabled_precisions = enabled_precisions_set ,
358
399
** kwargs ,
359
400
)
@@ -408,6 +449,7 @@ def save(
408
449
* ,
409
450
output_format : str = "exported_program" ,
410
451
inputs : Optional [Sequence [torch .Tensor ]] = None ,
452
+ arg_inputs : Optional [Sequence [torch .Tensor ]] = None ,
411
453
kwargs_inputs : Optional [dict [str , Any ]] = None ,
412
454
retrace : bool = False ,
413
455
) -> None :
@@ -417,18 +459,27 @@ def save(
417
459
Arguments:
418
460
module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)): Compiled Torch-TensorRT module
419
461
inputs (torch.Tensor): Torch input tensors
462
+ arg_inputs (Tuple[Any, ...]): Same as inputs; an alias that reads more clearly alongside kwargs_inputs. Pass either inputs or arg_inputs, not both.
463
+ kwargs_inputs (dict[str, Any]): Optional keyword inputs to the module's forward function.
420
464
output_format (str): Format to save the model. Options include exported_program | torchscript.
421
465
retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
422
466
This flag is experimental for now.
423
467
"""
424
468
module_type = _parse_module_type (module )
425
469
accepted_formats = {"exported_program" , "torchscript" }
426
- if inputs is not None and not all (
427
- isinstance (input , torch .Tensor ) for input in inputs
470
+ if arg_inputs is not None and not all (
471
+ isinstance (input , torch .Tensor ) for input in arg_inputs
428
472
):
429
473
raise ValueError (
430
474
"Not all inputs provided are torch.tensors. Please provide torch.tensors as inputs"
431
475
)
476
+ if arg_inputs and inputs :
477
+ raise AssertionError (
478
+ "'arg_input' and 'input' should not be used at the same time."
479
+ )
480
+
481
+ arg_inputs = inputs or arg_inputs
482
+
432
483
if kwargs_inputs is None :
433
484
kwargs_inputs = {}
434
485
@@ -460,27 +511,27 @@ def save(
460
511
else :
461
512
torch .export .save (module , file_path )
462
513
elif module_type == _ModuleType .fx :
463
- if inputs is None :
514
+ if arg_inputs is None :
464
515
raise ValueError (
465
516
"Provided model is a torch.fx.GraphModule however the inputs are empty. Please provide valid torch.tensors as inputs to trace and save the model"
466
517
)
467
518
# The module type is torch.fx.GraphModule
468
519
if output_format == "torchscript" :
469
520
module_ts = torch .jit .trace (
470
- module , inputs , example_kwarg_inputs = kwargs_inputs
521
+ module , arg_inputs , example_kwarg_inputs = kwargs_inputs
471
522
)
472
523
torch .jit .save (module_ts , file_path )
473
524
else :
474
525
if not retrace :
475
526
from torch_tensorrt .dynamo ._exporter import export
476
527
477
- exp_program = export (module , inputs , kwargs_inputs )
528
+ exp_program = export (module , arg_inputs , kwargs_inputs )
478
529
torch .export .save (exp_program , file_path )
479
530
else :
480
531
from torch ._higher_order_ops .torchbind import enable_torchbind_tracing
481
532
482
533
with enable_torchbind_tracing ():
483
534
exp_program = torch .export .export (
484
- module , tuple (inputs ), kwargs = kwargs_inputs , strict = False
535
+ module , tuple (arg_inputs ), kwargs = kwargs_inputs , strict = False
485
536
)
486
537
torch .export .save (exp_program , file_path )
0 commit comments