Skip to content

Commit bac3099

Browse files
committed
Supported kwargs save
1 parent 00f6141 commit bac3099

File tree

3 files changed

+22
-10
lines changed

3 files changed

+22
-10
lines changed

py/torch_tensorrt/_compile.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ def convert_method_to_trt_engine(
351351
torchtrt_inputs = prepare_inputs(inputs)
352352
exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
353353

354-
return dynamo_convert_module_to_trt_engine(
354+
return dynamo_convert_module_to_trt_engine( # type: ignore
355355
exp_program,
356356
inputs=tuple(inputs),
357357
enabled_precisions=enabled_precisions_set,
@@ -408,6 +408,7 @@ def save(
408408
*,
409409
output_format: str = "exported_program",
410410
inputs: Optional[Sequence[torch.Tensor]] = None,
411+
kwargs_inputs: Optional[dict[str, Any]] = None,
411412
retrace: bool = False,
412413
) -> None:
413414
"""
@@ -428,6 +429,10 @@ def save(
428429
raise ValueError(
429430
"Not all inputs provided are torch.tensors. Please provide torch.tensors as inputs"
430431
)
432+
if kwargs_inputs is not None and not all(
433+
value is not None for value in kwargs_inputs.values()
434+
):
435+
raise ValueError("kwargs should not include None.")
431436
if output_format not in accepted_formats:
432437
raise ValueError(
433438
f"Provided output_format {output_format} is not supported. Supported options are exported_program | torchscript"
@@ -460,19 +465,21 @@ def save(
460465
)
461466
# The module type is torch.fx.GraphModule
462467
if output_format == "torchscript":
463-
module_ts = torch.jit.trace(module, inputs)
468+
module_ts = torch.jit.trace(
469+
module, inputs, example_kwarg_inputs=kwargs_inputs
470+
)
464471
torch.jit.save(module_ts, file_path)
465472
else:
466473
if not retrace:
467474
from torch_tensorrt.dynamo._exporter import export
468475

469-
exp_program = export(module, inputs)
476+
exp_program = export(module, inputs, kwargs_inputs)
470477
torch.export.save(exp_program, file_path)
471478
else:
472479
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
473480

474481
with enable_torchbind_tracing():
475482
exp_program = torch.export.export(
476-
module, tuple(inputs), strict=False
483+
module, tuple(inputs), kwargs=kwargs_inputs, strict=False
477484
)
478485
torch.export.save(exp_program, file_path)

py/torch_tensorrt/dynamo/_exporter.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import copy
22
import operator
3-
from typing import Any, Dict, Sequence, Tuple, cast
3+
from typing import Any, Dict, Optional, Sequence, Tuple, cast
44

55
import torch
66
from torch._guards import detect_fake_mode
@@ -22,20 +22,23 @@
2222
def export(
    gm: torch.fx.GraphModule,
    inputs: Sequence[torch.Tensor],
    kwargs_inputs: Optional[dict[str, Any]] = None,
) -> ExportedProgram:
    """Export the result of TensorRT compilation into the desired output format.

    Arguments:
        gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
        inputs (torch.Tensor): Torch input tensors
        kwargs_inputs (Optional[dict[str, Any]]): Keyword arguments passed on to
            ``transform`` together with ``inputs``. ``None`` (the default) is
            treated as "no keyword arguments".

    Returns:
        ExportedProgram: The value produced by ``create_trt_exp_program`` for the
        transformed module.
    """
    # ``transform`` forwards kwargs_inputs to ``partitioning.run_shape_analysis``,
    # which unpacks it with ``**`` when calling the module. Unpacking ``None``
    # raises ``TypeError``, so normalize the default to an empty dict here.
    if kwargs_inputs is None:
        kwargs_inputs = {}

    patched_module = transform(gm, inputs, kwargs_inputs)
    exp_program = create_trt_exp_program(patched_module)
    return exp_program
3536

3637

3738
def transform(
38-
gm: torch.fx.GraphModule, inputs: Sequence[torch.Tensor]
39+
gm: torch.fx.GraphModule,
40+
inputs: Sequence[torch.Tensor],
41+
kwargs_inputs: Optional[dict[str, Any]] = None,
3942
) -> torch.fx.GraphModule:
4043
"""
4144
Transforms the graphmodule by inlining Pytorch and TensorRT submodules.
@@ -53,7 +56,7 @@ def transform(
5356
gm = copy.deepcopy(gm)
5457

5558
# Run shape analysis
56-
_, outputs_map = partitioning.run_shape_analysis(gm, inputs)
59+
_, outputs_map = partitioning.run_shape_analysis(gm, inputs, kwargs_inputs)
5760

5861
# Inline TensorRT submodules
5962
inline_trt_modules(gm, outputs_map)

py/torch_tensorrt/dynamo/partitioning/common.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,9 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
129129

130130

131131
def run_shape_analysis(
132-
parent_module: torch.fx.GraphModule, inputs: Sequence[Input]
132+
parent_module: torch.fx.GraphModule,
133+
inputs: Sequence[Input],
134+
kwargs_inputs: Optional[dict[str, Any]] = None,
133135
) -> Tuple[Dict[Any, Sequence[Any]], Dict[Any, Sequence[Any]]]:
134136
submod_inputs_shape_map: Dict[Any, Sequence[Any]] = {}
135137
submod_outputs_shape_map: Dict[Any, Sequence[Any]] = {}
@@ -149,7 +151,7 @@ def get_submodule_io(
149151
for name, _ in parent_module.named_children():
150152
submodule = getattr(parent_module, name)
151153
handle = submodule.register_forward_hook(get_submodule_io)
152-
parent_module(*inputs)
154+
parent_module(*inputs, **kwargs_inputs)
153155
handle.remove()
154156
submod_inputs_shape_map[name] = (
155157
[input.shape for input in sub_inputs]

0 commit comments

Comments
 (0)