Skip to content

Commit 3a0b244

Browse files
authored
Added kwarg support for dynamo.compile (#2970)
Approved by Naren.
1 parent 994ed05 commit 3a0b244

27 files changed

+935
-173
lines changed

docsrc/py_api/dynamo.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ Functions
2626

2727
.. autofunction:: refit_module_weights
2828

29-
3029
Classes
3130
--------
3231

examples/dynamo/refit_engine_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@
7878
new_trt_gm = refit_module_weights(
7979
compiled_module=compiled_trt_ep,
8080
new_weight_module=exp_program2,
81-
inputs=inputs,
81+
arg_inputs=inputs,
8282
)
8383

8484
# Check the output

py/torch_tensorrt/_compile.py

Lines changed: 78 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ def compile(
148148
module: Any,
149149
ir: str = "default",
150150
inputs: Optional[Sequence[Input | torch.Tensor | InputTensorSpec]] = None,
151+
arg_inputs: Optional[Sequence[Sequence[Any]]] = None,
152+
kwarg_inputs: Optional[dict[Any, Any]] = None,
151153
enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
152154
**kwargs: Any,
153155
) -> (
@@ -180,14 +182,16 @@ def compile(
180182
), # Dynamic input shape for input #2
181183
torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
182184
]
183-
185+
arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
186+
kwarg_inputs (dict[Any, Any]): Optional, keyword-argument inputs to the module forward function.
184187
enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
185188
ir (str): The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path)
186189
**kwargs: Additional settings for the specific requested strategy (See submodules for more info)
187190
188191
Returns:
189192
torch.nn.Module: Compiled Module, when run it will execute via TensorRT
190193
"""
194+
191195
input_list = inputs if inputs is not None else []
192196
enabled_precisions_set: Set[dtype | torch.dtype] = (
193197
enabled_precisions
@@ -238,17 +242,33 @@ def compile(
238242
return compiled_fx_module
239243
elif target_ir == _IRType.dynamo:
240244
# Prepare torch and torchtrt inputs
245+
if not arg_inputs and not inputs:
246+
raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
247+
248+
elif arg_inputs and inputs:
249+
raise AssertionError(
250+
"'arg_inputs' and 'inputs' should not be used at the same time."
251+
)
252+
arg_inputs = inputs or arg_inputs
253+
254+
if kwarg_inputs is None:
255+
kwarg_inputs = {}
256+
241257
from torch_tensorrt.dynamo.utils import prepare_inputs
242258

243-
if not isinstance(input_list, collections.abc.Sequence):
244-
input_list = [input_list]
259+
if not isinstance(arg_inputs, collections.abc.Sequence):
260+
arg_inputs = [arg_inputs] # type: ignore
245261

246262
# Export the module
247-
torchtrt_inputs = prepare_inputs(input_list)
248-
exp_program = dynamo_trace(module, torchtrt_inputs, **kwargs)
263+
torchtrt_arg_inputs = prepare_inputs(arg_inputs)
264+
torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs)
265+
266+
exp_program = dynamo_trace(
267+
module, torchtrt_arg_inputs, kwarg_inputs=torchtrt_kwarg_inputs, **kwargs
268+
)
249269
trt_graph_module = dynamo_compile(
250270
exp_program,
251-
inputs=torchtrt_inputs,
271+
arg_inputs=torchtrt_arg_inputs,
252272
enabled_precisions=enabled_precisions_set,
253273
**kwargs,
254274
)
@@ -280,7 +300,9 @@ def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Any:
280300
def convert_method_to_trt_engine(
281301
module: Any,
282302
method_name: str = "forward",
283-
inputs: Optional[Sequence[Input | torch.Tensor]] = None,
303+
inputs: Optional[Sequence[Input | torch.Tensor | InputTensorSpec]] = None,
304+
arg_inputs: Optional[Sequence[Sequence[Any]]] = None,
305+
kwarg_inputs: Optional[dict[Any, Any]] = None,
284306
ir: str = "default",
285307
enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
286308
**kwargs: Any,
@@ -309,6 +331,8 @@ def convert_method_to_trt_engine(
309331
torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
310332
]
311333
334+
arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
335+
kwarg_inputs (dict[Any, Any]): Optional, keyword-argument inputs to the module forward function.
312336
enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
313337
ir (str): The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path)
314338
**kwargs: Additional settings for the specific requested strategy (See submodules for more info)
@@ -330,7 +354,7 @@ def convert_method_to_trt_engine(
330354
ts_mod = torch.jit.script(module)
331355
serialized_engine: bytes = ts_convert_method_to_trt_engine(
332356
ts_mod,
333-
inputs=inputs,
357+
inputs=arg_inputs,
334358
method_name=method_name,
335359
enabled_precisions=enabled_precisions_set,
336360
**kwargs,
@@ -342,18 +366,35 @@ def convert_method_to_trt_engine(
342366
)
343367
elif target_ir == _IRType.dynamo:
344368
# Prepare torch and torchtrt inputs
369+
if not arg_inputs and not inputs:
370+
raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
371+
372+
elif arg_inputs and inputs:
373+
raise AssertionError(
374+
"'arg_inputs' and 'inputs' should not be used at the same time."
375+
)
376+
arg_inputs = arg_inputs or inputs
377+
378+
if kwarg_inputs is None:
379+
kwarg_inputs = {}
380+
345381
from torch_tensorrt.dynamo.utils import prepare_inputs
346382

347-
if not isinstance(inputs, collections.abc.Sequence):
348-
inputs = [inputs]
383+
if not isinstance(arg_inputs, collections.abc.Sequence):
384+
arg_inputs = [arg_inputs] # type: ignore
349385

350386
# Export the module
351-
torchtrt_inputs = prepare_inputs(inputs)
352-
exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
387+
torchtrt_arg_inputs = prepare_inputs(arg_inputs)
388+
torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs)
389+
390+
exp_program = torch_tensorrt.dynamo.trace(
391+
module, torchtrt_arg_inputs, kwarg_inputs=torchtrt_kwarg_inputs, **kwargs
392+
)
353393

354394
return dynamo_convert_exported_program_to_serialized_trt_engine(
355395
exp_program,
356-
inputs=tuple(inputs),
396+
arg_inputs=tuple(arg_inputs),
397+
kwarg_inputs=torchtrt_kwarg_inputs,
357398
enabled_precisions=enabled_precisions_set,
358399
**kwargs,
359400
)
@@ -408,6 +449,8 @@ def save(
408449
*,
409450
output_format: str = "exported_program",
410451
inputs: Optional[Sequence[torch.Tensor]] = None,
452+
arg_inputs: Optional[Sequence[torch.Tensor]] = None,
453+
kwarg_inputs: Optional[dict[str, Any]] = None,
411454
retrace: bool = False,
412455
) -> None:
413456
"""
@@ -416,18 +459,32 @@ def save(
416459
Arguments:
417460
module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)): Compiled Torch-TensorRT module
418461
inputs (torch.Tensor): Torch input tensors
462+
arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
463+
kwarg_inputs (dict[str, Any]): Optional, keyword-argument inputs to the module forward function. Values must not be None.
419464
output_format (str): Format to save the model. Options include exported_program | torchscript.
420465
retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
421466
This flag is experimental for now.
422467
"""
423468
module_type = _parse_module_type(module)
424469
accepted_formats = {"exported_program", "torchscript"}
425-
if inputs is not None and not all(
426-
isinstance(input, torch.Tensor) for input in inputs
470+
if arg_inputs is not None and not all(
471+
isinstance(input, torch.Tensor) for input in arg_inputs
427472
):
428473
raise ValueError(
429474
"Not all inputs provided are torch.tensors. Please provide torch.tensors as inputs"
430475
)
476+
if arg_inputs and inputs:
477+
raise AssertionError(
478+
"'arg_inputs' and 'inputs' should not be used at the same time."
479+
)
480+
481+
arg_inputs = inputs or arg_inputs
482+
483+
if kwarg_inputs is None:
484+
kwarg_inputs = {}
485+
486+
if kwarg_inputs and any(value is None for value in kwarg_inputs.values()):
487+
raise ValueError("kwargs should not include None.")
431488
if output_format not in accepted_formats:
432489
raise ValueError(
433490
f"Provided output_format {output_format} is not supported. Supported options are exported_program | torchscript"
@@ -454,25 +511,27 @@ def save(
454511
else:
455512
torch.export.save(module, file_path)
456513
elif module_type == _ModuleType.fx:
457-
if inputs is None:
514+
if arg_inputs is None:
458515
raise ValueError(
459516
"Provided model is a torch.fx.GraphModule however the inputs are empty. Please provide valid torch.tensors as inputs to trace and save the model"
460517
)
461518
# The module type is torch.fx.GraphModule
462519
if output_format == "torchscript":
463-
module_ts = torch.jit.trace(module, inputs)
520+
module_ts = torch.jit.trace(
521+
module, arg_inputs, example_kwarg_inputs=kwarg_inputs
522+
)
464523
torch.jit.save(module_ts, file_path)
465524
else:
466525
if not retrace:
467526
from torch_tensorrt.dynamo._exporter import export
468527

469-
exp_program = export(module, inputs)
528+
exp_program = export(module, arg_inputs, kwarg_inputs)
470529
torch.export.save(exp_program, file_path)
471530
else:
472531
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
473532

474533
with enable_torchbind_tracing():
475534
exp_program = torch.export.export(
476-
module, tuple(inputs), strict=False
535+
module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False
477536
)
478537
torch.export.save(exp_program, file_path)

0 commit comments

Comments
 (0)