
Commit 2a17ebd

Changed the user API to include inputs, arg_inputs, kwarg_inputs
1 parent 8ecef1c commit 2a17ebd

File tree

3 files changed (+125, -30 lines)
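
For context, a minimal usage sketch of the API change this commit describes: `arg_inputs` becomes an alias for `inputs`, and `kwarg_inputs` carries example keyword arguments for the module's forward(). The toy module, shapes, and `ir` choice below are illustrative assumptions, not taken from the commit, and running it requires a TensorRT-capable GPU.

    import torch
    import torch_tensorrt


    class ToyModel(torch.nn.Module):
        def forward(self, x, bias=None):
            out = torch.relu(x)
            if bias is not None:
                out = out + bias
            return out


    model = ToyModel().eval().cuda()

    # Previously: example inputs could only be passed positionally via `inputs`.
    trt_model = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=[torch.randn(1, 3, 224, 224).cuda()],
    )

    # After this commit: `arg_inputs` is an alias for `inputs`, and `kwarg_inputs`
    # supplies example keyword arguments for forward().
    trt_model = torch_tensorrt.compile(
        model,
        ir="dynamo",
        arg_inputs=[torch.randn(1, 3, 224, 224).cuda()],
        kwarg_inputs={"bias": torch.zeros(1, device="cuda")},
    )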

py/torch_tensorrt/_compile.py

Lines changed: 70 additions & 19 deletions
@@ -148,6 +148,8 @@ def compile(
     module: Any,
     ir: str = "default",
     inputs: Optional[Sequence[Input | torch.Tensor | InputTensorSpec]] = None,
+    arg_inputs: Optional[Sequence[Sequence[Any]]] = None,
+    kwarg_inputs: Optional[dict[Any, Any]] = None,
     enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
     **kwargs: Any,
 ) -> (
@@ -180,14 +182,16 @@ def compile(
                 ), # Dynamic input shape for input #2
                 torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
             ]
-
+        arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
+        kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
         enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
         ir (str): The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path)
         **kwargs: Additional settings for the specific requested strategy (See submodules for more info)

     Returns:
         torch.nn.Module: Compiled Module, when run it will execute via TensorRT
     """
+
     input_list = inputs if inputs is not None else []
     enabled_precisions_set: Set[dtype | torch.dtype] = (
         enabled_precisions
@@ -238,17 +242,33 @@ def compile(
         return compiled_fx_module
     elif target_ir == _IRType.dynamo:
         # Prepare torch and torchtrt inputs
+        if not arg_inputs and not inputs:
+            raise AssertionError("'arg_input' or 'input' should not be None.")
+
+        elif arg_inputs and inputs:
+            raise AssertionError(
+                "'arg_input' and 'input' should not be used at the same time."
+            )
+        arg_inputs = inputs or arg_inputs
+
+        if kwarg_inputs is None:
+            kwarg_inputs = {}
+
         from torch_tensorrt.dynamo.utils import prepare_inputs

-        if not isinstance(input_list, collections.abc.Sequence):
-            input_list = [input_list]
+        if not isinstance(arg_inputs, collections.abc.Sequence):
+            arg_inputs = [arg_inputs]  # type: ignore

         # Export the module
-        torchtrt_inputs = prepare_inputs(input_list)
-        exp_program = dynamo_trace(module, torchtrt_inputs, **kwargs)
+        torchtrt_arg_inputs = prepare_inputs(arg_inputs)
+        torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs)
+
+        exp_program = dynamo_trace(
+            module, torchtrt_arg_inputs, kwarg_inputs=torchtrt_kwarg_inputs, **kwargs
+        )
         trt_graph_module = dynamo_compile(
             exp_program,
-            inputs=torchtrt_inputs,
+            arg_inputs=torchtrt_arg_inputs,
             enabled_precisions=enabled_precisions_set,
             **kwargs,
         )
@@ -280,7 +300,9 @@ def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Any:
 def convert_method_to_trt_engine(
     module: Any,
     method_name: str = "forward",
-    inputs: Optional[Sequence[Input | torch.Tensor]] = None,
+    inputs: Optional[Sequence[Input | torch.Tensor | InputTensorSpec]] = None,
+    arg_inputs: Optional[Sequence[Sequence[Any]]] = None,
+    kwarg_inputs: Optional[dict[Any, Any]] = None,
     ir: str = "default",
     enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
     **kwargs: Any,
@@ -309,6 +331,8 @@ def convert_method_to_trt_engine(
                 torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
             ]

+        arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
+        kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
         enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
         ir (str): The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path)
         **kwargs: Additional settings for the specific requested strategy (See submodules for more info)
@@ -330,7 +354,7 @@ def convert_method_to_trt_engine(
         ts_mod = torch.jit.script(module)
         serialized_engine: bytes = ts_convert_method_to_trt_engine(
             ts_mod,
-            inputs=inputs,
+            inputs=arg_inputs,
             method_name=method_name,
             enabled_precisions=enabled_precisions_set,
             **kwargs,
@@ -342,18 +366,35 @@ def convert_method_to_trt_engine(
         )
     elif target_ir == _IRType.dynamo:
         # Prepare torch and torchtrt inputs
+        if not arg_inputs and not inputs:
+            raise AssertionError("'arg_input' or 'input' should not be None.")
+
+        elif arg_inputs and inputs:
+            raise AssertionError(
+                "'arg_input' and 'input' should not be used at the same time."
+            )
+        arg_inputs = arg_inputs or inputs
+
+        if kwarg_inputs is None:
+            kwarg_inputs = {}
+
         from torch_tensorrt.dynamo.utils import prepare_inputs

-        if not isinstance(inputs, collections.abc.Sequence):
-            inputs = [inputs]
+        if not isinstance(arg_inputs, collections.abc.Sequence):
+            arg_inputs = [arg_inputs]  # type: ignore

         # Export the module
-        torchtrt_inputs = prepare_inputs(inputs)
-        exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
+        torchtrt_arg_inputs = prepare_inputs(arg_inputs)
+        torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs)
+
+        exp_program = torch_tensorrt.dynamo.trace(
+            module, torchtrt_arg_inputs, kwarg_inputs=torchtrt_kwarg_inputs, **kwargs
+        )

         return dynamo_convert_module_to_trt_engine(
             exp_program,
-            inputs=tuple(inputs),
+            arg_inputs=tuple(arg_inputs),
+            kwarg_inputs=torchtrt_kwarg_inputs,
             enabled_precisions=enabled_precisions_set,
             **kwargs,
         )
@@ -408,6 +449,7 @@ def save(
     *,
     output_format: str = "exported_program",
     inputs: Optional[Sequence[torch.Tensor]] = None,
+    arg_inputs: Optional[Sequence[torch.Tensor]] = None,
     kwargs_inputs: Optional[dict[str, Any]] = None,
     retrace: bool = False,
 ) -> None:
@@ -417,18 +459,27 @@ def save(
     Arguments:
         module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)): Compiled Torch-TensorRT module
         inputs (torch.Tensor): Torch input tensors
+        arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
+        kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
         output_format (str): Format to save the model. Options include exported_program | torchscript.
         retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
             This flag is experimental for now.
     """
     module_type = _parse_module_type(module)
     accepted_formats = {"exported_program", "torchscript"}
-    if inputs is not None and not all(
-        isinstance(input, torch.Tensor) for input in inputs
+    if arg_inputs is not None and not all(
+        isinstance(input, torch.Tensor) for input in arg_inputs
     ):
         raise ValueError(
             "Not all inputs provided are torch.tensors. Please provide torch.tensors as inputs"
         )
+    if arg_inputs and inputs:
+        raise AssertionError(
+            "'arg_input' and 'input' should not be used at the same time."
+        )
+
+    arg_inputs = inputs or arg_inputs
+
     if kwargs_inputs is None:
         kwargs_inputs = {}

@@ -460,27 +511,27 @@ def save(
         else:
             torch.export.save(module, file_path)
     elif module_type == _ModuleType.fx:
-        if inputs is None:
+        if arg_inputs is None:
             raise ValueError(
                 "Provided model is a torch.fx.GraphModule however the inputs are empty. Please provide valid torch.tensors as inputs to trace and save the model"
             )
         # The module type is torch.fx.GraphModule
         if output_format == "torchscript":
             module_ts = torch.jit.trace(
-                module, inputs, example_kwarg_inputs=kwargs_inputs
+                module, arg_inputs, example_kwarg_inputs=kwargs_inputs
             )
             torch.jit.save(module_ts, file_path)
         else:
             if not retrace:
                 from torch_tensorrt.dynamo._exporter import export

-                exp_program = export(module, inputs, kwargs_inputs)
+                exp_program = export(module, arg_inputs, kwargs_inputs)
                 torch.export.save(exp_program, file_path)
             else:
                 from torch._higher_order_ops.torchbind import enable_torchbind_tracing

                 with enable_torchbind_tracing():
                     exp_program = torch.export.export(
-                        module, tuple(inputs), kwargs=kwargs_inputs, strict=False
+                        module, tuple(arg_inputs), kwargs=kwargs_inputs, strict=False
                     )
                     torch.export.save(exp_program, file_path)
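
The same inputs/arg_inputs reconciliation block now appears in compile(), convert_method_to_trt_engine(), and save() above. Below is a standalone, runnable sketch of just that pattern; the helper name resolve_arg_inputs is illustrative and is not part of the commit.

    # Sketch of the reconciliation pattern used throughout this file:
    # exactly one of the two aliases must be supplied, and whichever one
    # is given becomes the effective positional-example list.
    from typing import Any, Optional, Sequence


    def resolve_arg_inputs(
        inputs: Optional[Sequence[Any]] = None,
        arg_inputs: Optional[Sequence[Any]] = None,
    ) -> Sequence[Any]:
        if not arg_inputs and not inputs:
            raise AssertionError("'arg_input' or 'input' should not be None.")
        elif arg_inputs and inputs:
            raise AssertionError(
                "'arg_input' and 'input' should not be used at the same time."
            )
        return inputs or arg_inputs  # whichever alias was provided


    print(resolve_arg_inputs(inputs=[1, 2, 3]))      # [1, 2, 3]
    print(resolve_arg_inputs(arg_inputs=[4, 5, 6]))  # [4, 5, 6]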

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 27 additions & 6 deletions
@@ -48,8 +48,9 @@

 def compile(
     exported_program: ExportedProgram,
-    inputs: Sequence[Any],
+    inputs: Optional[Sequence[Sequence[Any]]] = None,
     *,
+    arg_inputs: Optional[Sequence[Sequence[Any]]] = None,
     kwarg_inputs: Optional[dict[Any, Any]] = None,
     device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE,
     disable_tf32: bool = _defaults.DISABLE_TF32,
@@ -111,6 +112,8 @@ def compile(
             ]

     Keyword Arguments:
+        arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
+        kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
         device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on ::

             device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
@@ -183,13 +186,21 @@ def compile(
     )

     # Aliasing inputs to arg_inputs for better understanding
-    arg_inputs = inputs
+    if not arg_inputs and not inputs:
+        raise AssertionError("'arg_input' or 'input' should not be None.")
+
+    elif arg_inputs and inputs:
+        raise AssertionError(
+            "'arg_input' and 'input' should not be used at the same time."
+        )
+
+    arg_inputs = inputs or arg_inputs

     if kwarg_inputs is None:
         kwarg_inputs = {}

     if not isinstance(arg_inputs, collections.abc.Sequence):
-        arg_inputs = [arg_inputs]
+        arg_inputs = [arg_inputs]  # type: ignore

     # Prepare torch_trt inputs
     trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
@@ -481,9 +492,10 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:

 def convert_module_to_trt_engine(
     exported_program: ExportedProgram,
-    inputs: Sequence[Any],
-    kwarg_inputs: Optional[dict[str, Any]] = None,
+    inputs: Optional[Sequence[Sequence[Any]]] = None,
     *,
+    arg_inputs: Optional[Sequence[Sequence[Any]]] = None,
+    kwarg_inputs: Optional[dict[Any, Any]] = None,
     enabled_precisions: (
         Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
     ) = _defaults.ENABLED_PRECISIONS,
@@ -595,8 +607,17 @@ def convert_module_to_trt_engine(
         DeprecationWarning,
         stacklevel=2,
     )
+    if not arg_inputs and not inputs:
+        raise AssertionError("'arg_input' or 'input' should not be None.")
+
+    elif arg_inputs and inputs:
+        raise AssertionError(
+            "'arg_input' and 'input' should not be used at the same time."
+        )
+
+    arg_inputs = inputs or arg_inputs

-    arg_input_list = list(inputs) if inputs is not None else []
+    arg_input_list = list(arg_inputs) if arg_inputs is not None else []
     torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set()
     if kwarg_inputs is None:
         kwarg_inputs = {}
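
A hedged sketch of calling the dynamo-level compile() with the new keyword-only arg_inputs and kwarg_inputs. The toy module, shapes, and precision set are illustrative assumptions, and running it requires a TensorRT-capable GPU.

    import torch
    import torch_tensorrt


    class ToyModel(torch.nn.Module):
        def forward(self, x, offset=None):
            return torch.relu(x + offset)


    args = (torch.randn(1, 3, 32, 32).cuda(),)
    kwargs = {"offset": torch.ones(1, device="cuda")}

    # Export with both positional and keyword example inputs, then compile;
    # kwargs are now threaded through to the TensorRT compilation path.
    exp_program = torch.export.export(ToyModel().eval().cuda(), args, kwargs=kwargs)
    trt_module = torch_tensorrt.dynamo.compile(
        exp_program,
        arg_inputs=list(args),
        kwarg_inputs=kwargs,
        enabled_precisions={torch.float32},
    )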

py/torch_tensorrt/dynamo/_tracer.py

Lines changed: 28 additions & 5 deletions
@@ -1,7 +1,7 @@
 from __future__ import annotations

 import logging
-from typing import Any, Tuple
+from typing import Any, Optional, Tuple

 import torch
 from torch.export import Dim, export
@@ -14,7 +14,10 @@

 def trace(
     mod: torch.nn.Module | torch.fx.GraphModule,
-    inputs: Tuple[Any, ...],
+    inputs: Optional[Tuple[Any, ...]] = None,
+    *,
+    arg_inputs: Optional[Tuple[Any, ...]] = None,
+    kwarg_inputs: Optional[dict[Any, Any]] = None,
     **kwargs: Any,
 ) -> torch.export.ExportedProgram:
     """Exports a ``torch.export.ExportedProgram`` from a ``torch.nn.Module`` or ``torch.fx.GraphModule`` specifically targeting being compiled with Torch-TensorRT
@@ -40,6 +43,8 @@ def trace(
                 torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
             ]
     Keyword Arguments:
+        arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
+        kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
         device (Union(torch.device, dict)): Target device for TensorRT engines to run on ::

             device=torch.device("cuda:0")
@@ -52,14 +57,27 @@ def trace(
     """

     # Set log level at the top of compilation (torch_tensorrt.dynamo)
+    if not arg_inputs and not inputs:
+        raise AssertionError("'arg_input' or 'input' should not be None.")
+
+    elif arg_inputs and inputs:
+        raise AssertionError(
+            "'arg_input' and 'input' should not be used at the same time."
+        )
+    arg_inputs = inputs or arg_inputs
+
+    if kwarg_inputs is None:
+        kwarg_inputs = {}
+
     debug = kwargs.get("debug", DEBUG)
     if debug:
         set_log_level(logger.parent, logging.DEBUG)

     device = to_torch_device(kwargs.get("device", default_device()))
-    torch_inputs = get_torch_inputs(inputs, device)
+    torch_arg_inputs = get_torch_inputs(arg_inputs, device)
+    torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
     dynamic_shapes = []
-    for input in inputs:
+    for input in arg_inputs:  # type: ignore
         if isinstance(input, Input) and input.shape_mode == Input._ShapeMode.DYNAMIC:
             min_shape = input.shape["min_shape"]
             opt_shape = input.shape["opt_shape"]
@@ -78,6 +96,11 @@ def trace(

         dynamic_shapes.append(dynamic_dims)

-    exp_program = export(mod, tuple(torch_inputs), dynamic_shapes=tuple(dynamic_shapes))
+    exp_program = export(
+        mod,
+        tuple(torch_arg_inputs),
+        kwargs=torch_kwarg_inputs,
+        dynamic_shapes=tuple(dynamic_shapes),
+    )

     return exp_program
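
A hedged sketch of the updated trace() call, where kwarg_inputs is forwarded into torch.export.export as its kwargs argument. The module and example tensors are illustrative, and the default device resolution is assumed to pick an available CUDA device.

    import torch
    import torch_tensorrt


    class ToyModel(torch.nn.Module):
        def forward(self, x, offset=None):
            return x + offset


    # arg_inputs and kwarg_inputs are keyword-only after this change.
    exp_program = torch_tensorrt.dynamo.trace(
        ToyModel().eval(),
        arg_inputs=(torch.randn(2, 4),),
        kwarg_inputs={"offset": torch.ones(1)},
    )
    print(exp_program)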
