Skip to content

Added kwarg support for dynamo.compile #2970

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docsrc/py_api/dynamo.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ Functions

.. autofunction:: refit_module_weights


Classes
--------

Expand Down
2 changes: 1 addition & 1 deletion examples/dynamo/refit_engine_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
new_trt_gm = refit_module_weights(
compiled_module=compiled_trt_ep,
new_weight_module=exp_program2,
inputs=inputs,
arg_inputs=inputs,
)

# Check the output
Expand Down
97 changes: 78 additions & 19 deletions py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ def compile(
module: Any,
ir: str = "default",
inputs: Optional[Sequence[Input | torch.Tensor | InputTensorSpec]] = None,
arg_inputs: Optional[Sequence[Sequence[Any]]] = None,
kwarg_inputs: Optional[dict[Any, Any]] = None,
enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
**kwargs: Any,
) -> (
Expand Down Expand Up @@ -180,14 +182,16 @@ def compile(
), # Dynamic input shape for input #2
torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
]

arg_inputs (Tuple[Any, ...]): Same as inputs. Alias named to pair with kwarg_inputs; do not pass both inputs and arg_inputs.
kwarg_inputs (dict[Any, Any]): Optional, keyword inputs to the module forward function.
enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
ir (str): The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path)
**kwargs: Additional settings for the specific requested strategy (See submodules for more info)

Returns:
torch.nn.Module: Compiled Module, when run it will execute via TensorRT
"""

input_list = inputs if inputs is not None else []
enabled_precisions_set: Set[dtype | torch.dtype] = (
enabled_precisions
Expand Down Expand Up @@ -238,17 +242,33 @@ def compile(
return compiled_fx_module
elif target_ir == _IRType.dynamo:
# Prepare torch and torchtrt inputs
if not arg_inputs and not inputs:
raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")

elif arg_inputs and inputs:
raise AssertionError(
"'arg_inputs' and 'inputs' should not be used at the same time."
)
arg_inputs = inputs or arg_inputs

if kwarg_inputs is None:
kwarg_inputs = {}

from torch_tensorrt.dynamo.utils import prepare_inputs

if not isinstance(input_list, collections.abc.Sequence):
input_list = [input_list]
if not isinstance(arg_inputs, collections.abc.Sequence):
arg_inputs = [arg_inputs] # type: ignore

# Export the module
torchtrt_inputs = prepare_inputs(input_list)
exp_program = dynamo_trace(module, torchtrt_inputs, **kwargs)
torchtrt_arg_inputs = prepare_inputs(arg_inputs)
torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs)

exp_program = dynamo_trace(
module, torchtrt_arg_inputs, kwarg_inputs=torchtrt_kwarg_inputs, **kwargs
)
trt_graph_module = dynamo_compile(
exp_program,
inputs=torchtrt_inputs,
arg_inputs=torchtrt_arg_inputs,
enabled_precisions=enabled_precisions_set,
**kwargs,
)
Expand Down Expand Up @@ -280,7 +300,9 @@ def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Any:
def convert_method_to_trt_engine(
module: Any,
method_name: str = "forward",
inputs: Optional[Sequence[Input | torch.Tensor]] = None,
inputs: Optional[Sequence[Input | torch.Tensor | InputTensorSpec]] = None,
arg_inputs: Optional[Sequence[Sequence[Any]]] = None,
kwarg_inputs: Optional[dict[Any, Any]] = None,
ir: str = "default",
enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
**kwargs: Any,
Expand Down Expand Up @@ -309,6 +331,8 @@ def convert_method_to_trt_engine(
torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
]

arg_inputs (Tuple[Any, ...]): Same as inputs. Alias named to pair with kwarg_inputs; do not pass both inputs and arg_inputs.
kwarg_inputs (dict[Any, Any]): Optional, keyword inputs to the module forward function.
enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
ir (str): The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path)
**kwargs: Additional settings for the specific requested strategy (See submodules for more info)
Expand All @@ -330,7 +354,7 @@ def convert_method_to_trt_engine(
ts_mod = torch.jit.script(module)
serialized_engine: bytes = ts_convert_method_to_trt_engine(
ts_mod,
inputs=inputs,
inputs=arg_inputs,
method_name=method_name,
enabled_precisions=enabled_precisions_set,
**kwargs,
Expand All @@ -342,18 +366,35 @@ def convert_method_to_trt_engine(
)
elif target_ir == _IRType.dynamo:
# Prepare torch and torchtrt inputs
if not arg_inputs and not inputs:
raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")

elif arg_inputs and inputs:
raise AssertionError(
"'arg_inputs' and 'inputs' should not be used at the same time."
)
arg_inputs = arg_inputs or inputs

if kwarg_inputs is None:
kwarg_inputs = {}

from torch_tensorrt.dynamo.utils import prepare_inputs

if not isinstance(inputs, collections.abc.Sequence):
inputs = [inputs]
if not isinstance(arg_inputs, collections.abc.Sequence):
arg_inputs = [arg_inputs] # type: ignore

# Export the module
torchtrt_inputs = prepare_inputs(inputs)
exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
torchtrt_arg_inputs = prepare_inputs(arg_inputs)
torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs)

# Trace with positional example inputs plus keyword example inputs; the
# remaining caller settings are unpacked and forwarded as keyword arguments.
# (A missing comma previously turned this into `torchtrt_kwarg_inputs ** kwargs`,
# a dict power expression that raises TypeError.)
exp_program = torch_tensorrt.dynamo.trace(
    module,
    torchtrt_arg_inputs,
    kwarg_inputs=torchtrt_kwarg_inputs,
    **kwargs,
)

return dynamo_convert_module_to_trt_engine(
exp_program,
inputs=tuple(inputs),
arg_inputs=tuple(arg_inputs),
kwarg_inputs=torchtrt_kwarg_inputs,
enabled_precisions=enabled_precisions_set,
**kwargs,
)
Expand Down Expand Up @@ -408,6 +449,8 @@ def save(
*,
output_format: str = "exported_program",
inputs: Optional[Sequence[torch.Tensor]] = None,
arg_inputs: Optional[Sequence[torch.Tensor]] = None,
kwarg_inputs: Optional[dict[str, Any]] = None,
retrace: bool = False,
) -> None:
"""
Expand All @@ -416,18 +459,32 @@ def save(
Arguments:
module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)): Compiled Torch-TensorRT module
inputs (torch.Tensor): Torch input tensors
arg_inputs (Tuple[Any, ...]): Same as inputs. Alias named to pair with kwarg_inputs; do not pass both inputs and arg_inputs.
kwarg_inputs (dict[str, Any]): Optional, keyword inputs to the module forward function. Values must not be None.
output_format (str): Format to save the model. Options include exported_program | torchscript.
retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
This flag is experimental for now.
"""
module_type = _parse_module_type(module)
accepted_formats = {"exported_program", "torchscript"}
if inputs is not None and not all(
isinstance(input, torch.Tensor) for input in inputs
if arg_inputs is not None and not all(
isinstance(input, torch.Tensor) for input in arg_inputs
):
raise ValueError(
"Not all inputs provided are torch.tensors. Please provide torch.tensors as inputs"
)
if arg_inputs and inputs:
raise AssertionError(
"'arg_inputs' and 'inputs' should not be used at the same time."
)

arg_inputs = inputs or arg_inputs

if kwarg_inputs is None:
kwarg_inputs = {}

if kwarg_inputs and any(value is None for value in kwarg_inputs.values()):
raise ValueError("kwargs should not include None.")
if output_format not in accepted_formats:
raise ValueError(
f"Provided output_format {output_format} is not supported. Supported options are exported_program | torchscript"
Expand All @@ -454,25 +511,27 @@ def save(
else:
torch.export.save(module, file_path)
elif module_type == _ModuleType.fx:
if inputs is None:
if arg_inputs is None:
raise ValueError(
"Provided model is a torch.fx.GraphModule however the inputs are empty. Please provide valid torch.tensors as inputs to trace and save the model"
)
# The module type is torch.fx.GraphModule
if output_format == "torchscript":
module_ts = torch.jit.trace(module, inputs)
module_ts = torch.jit.trace(
module, arg_inputs, example_kwarg_inputs=kwarg_inputs
)
torch.jit.save(module_ts, file_path)
else:
if not retrace:
from torch_tensorrt.dynamo._exporter import export

exp_program = export(module, inputs)
exp_program = export(module, arg_inputs, kwarg_inputs)
torch.export.save(exp_program, file_path)
else:
from torch._higher_order_ops.torchbind import enable_torchbind_tracing

with enable_torchbind_tracing():
exp_program = torch.export.export(
module, tuple(inputs), strict=False
module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False
)
torch.export.save(exp_program, file_path)
Loading
Loading