Skip to content

Commit c78491e

Browse files
committed
Added kwarg support for convert_module_to_trt_engine
1 parent 2d592d2 commit c78491e

File tree

6 files changed

+59
-20
lines changed

6 files changed

+59
-20
lines changed

py/torch_tensorrt/_compile.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ def convert_method_to_trt_engine(
351351
torchtrt_inputs = prepare_inputs(inputs)
352352
exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
353353

354-
return dynamo_convert_module_to_trt_engine( # type: ignore
354+
return dynamo_convert_module_to_trt_engine(
355355
exp_program,
356356
inputs=tuple(inputs),
357357
enabled_precisions=enabled_precisions_set,
@@ -421,9 +421,10 @@ def save(
421421
raise ValueError(
422422
"Not all inputs provided are torch.tensors. Please provide torch.tensors as inputs"
423423
)
424-
if kwargs_inputs is not None and not all(
425-
value is not None for value in kwargs_inputs.values()
426-
):
424+
if kwargs_inputs is None:
425+
kwargs_inputs = {}
426+
427+
if kwargs_inputs and not all(value is not None for value in kwargs_inputs.values()):
427428
raise ValueError("kwargs should not include None.")
428429
if output_format not in accepted_formats:
429430
raise ValueError(

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
pre_export_lowering,
3535
)
3636
from torch_tensorrt.dynamo.utils import (
37+
flatten_dict_value,
3738
get_torch_inputs,
3839
parse_complex_tensor_structs,
3940
prepare_inputs,
@@ -480,7 +481,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
480481

481482
def convert_module_to_trt_engine(
482483
exported_program: ExportedProgram,
483-
inputs: Tuple[Any, ...],
484+
inputs: Sequence[Any],
485+
kwarg_inputs: Optional[dict[str, Any]] = None,
484486
*,
485487
enabled_precisions: (
486488
Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
@@ -592,12 +594,15 @@ def convert_module_to_trt_engine(
592594
stacklevel=2,
593595
)
594596

595-
input_list = list(inputs) if inputs is not None else []
597+
arg_input_list = list(inputs) if inputs is not None else []
596598
torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set()
599+
if kwarg_inputs is None:
600+
kwarg_inputs = {}
597601
# Prepare torch_trt inputs
598-
input_list = prepare_inputs(input_list)
602+
arg_input_list = prepare_inputs(arg_input_list)
603+
kwarg_input_list = prepare_inputs(kwarg_inputs)
604+
flattened_input_list = arg_input_list + flatten_dict_value(kwarg_input_list)
599605
device = to_torch_tensorrt_device(device)
600-
torch_inputs = get_torch_inputs(input_list, device)
601606
enabled_precisions = {dtype._from(e) for e in enabled_precisions}
602607

603608
compilation_options = {
@@ -646,8 +651,15 @@ def convert_module_to_trt_engine(
646651
# Assume converters support dynamic shapes and disable validation
647652
CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support)
648653

654+
interpreter_result = interpret_module_to_result(
655+
gm,
656+
inputs=flattened_input_list,
657+
arg_inputs=arg_input_list,
658+
kwarg_inputs=kwarg_input_list,
659+
settings=settings,
660+
)
649661
try:
650-
interpreter_result = interpret_module_to_result(gm, input_list, settings)
662+
pass
651663
except UnsupportedOperatorException:
652664
logger.error(
653665
f"Conversion of module {gm} not currently fully supported or convertible!",

py/torch_tensorrt/dynamo/_exporter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ def export(
3030
gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
3131
inputs (torch.Tensor): Torch input tensors
3232
"""
33+
if kwargs_inputs is None:
34+
kwargs_inputs = {}
3335
patched_module = transform(gm, inputs, kwargs_inputs)
3436
exp_program = create_trt_exp_program(patched_module)
3537
return exp_program
@@ -53,8 +55,9 @@ def transform(
5355
"""
5456
# Make a copy the graph since this function transforms the input graph and changes it's attributes.
5557
# This transformed graph is meant to be consumed by `create_trt_exp_program`
58+
if kwargs_inputs is None:
59+
kwargs_inputs = {}
5660
gm = copy.deepcopy(gm)
57-
5861
# Run shape analysis
5962
_, outputs_map = partitioning.run_shape_analysis(gm, inputs, kwargs_inputs)
6063

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import io
44
import logging
5-
from typing import List, Sequence
5+
from typing import Any, List, Optional, Sequence
66

77
import tensorrt as trt
88
import torch
@@ -26,12 +26,16 @@ def infer_module_output_dtypes(
2626
module: torch.fx.GraphModule,
2727
inputs: Sequence[Input],
2828
device: Device,
29+
kwarg_inputs: Optional[dict[str, Any]] = None,
2930
truncate_double: bool = False,
3031
) -> List[dtype]:
3132
with maybe_disable_fake_tensor_mode():
3233
torch_inputs = get_torch_inputs(inputs, device)
34+
if kwarg_inputs is None:
35+
kwarg_inputs = {}
36+
torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
3337
module = module.to(device.to(torch.device))
34-
module_outputs = module(*torch_inputs)
38+
module_outputs = module(*torch_inputs, **torch_kwarg_inputs)
3539
if not isinstance(module_outputs, (list, tuple)):
3640
module_outputs = [module_outputs]
3741

@@ -61,6 +65,8 @@ def interpret_module_to_result(
6165
module: torch.fx.GraphModule,
6266
inputs: Sequence[Input],
6367
settings: CompilationSettings = CompilationSettings(),
68+
arg_inputs: Optional[Sequence[Input]] = None,
69+
kwarg_inputs: Optional[dict[str, Any]] = None,
6470
) -> TRTInterpreterResult:
6571
"""Interpret an FX module to a TRTInterpreterResult
6672
Args:
@@ -70,12 +76,22 @@ def interpret_module_to_result(
7076
Returns:
7177
TRTInterpreterResult
7278
"""
73-
output_dtypes = infer_module_output_dtypes(
74-
module,
75-
inputs,
76-
settings.device,
77-
truncate_double=settings.truncate_double,
78-
)
79+
if arg_inputs is not None:
80+
output_dtypes = infer_module_output_dtypes(
81+
module,
82+
arg_inputs,
83+
settings.device,
84+
kwarg_inputs=kwarg_inputs,
85+
truncate_double=settings.truncate_double,
86+
)
87+
else:
88+
# args and kwargs are combined and flattened to one list
89+
output_dtypes = infer_module_output_dtypes(
90+
module,
91+
inputs,
92+
settings.device,
93+
truncate_double=settings.truncate_double,
94+
)
7995

8096
interpreter = TRTInterpreter(
8197
module,

py/torch_tensorrt/dynamo/partitioning/common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ def get_submodule_io(
147147
sub_outputs = outputs
148148
return
149149

150+
if kwargs_inputs is None:
151+
kwargs_inputs = {}
150152
# Iterate through submodules (both Torch and TRT) and store IO shapes
151153
for name, _ in parent_module.named_children():
152154
submodule = getattr(parent_module, name)

tests/py/dynamo/models/test_models_export_kwargs.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# type: ignore
2+
import os
3+
import tempfile
24
import unittest
35

46
import pytest
@@ -62,12 +64,15 @@ def forward(self, x, b=5, c=None, d=None):
6264
# trt_mod = torchtrt.compile(model, **compile_spec)
6365

6466
exp_program = torch.export.export(model, args=tuple(args), kwargs=kwargs)
65-
trt_mod = torchtrt.dynamo.compile(exp_program, **compile_spec)
66-
cos_sim = cosine_similarity(model(*args, **kwargs), trt_mod(*args, **kwargs)[0])
67+
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
68+
cos_sim = cosine_similarity(model(*args, **kwargs), trt_gm(*args, **kwargs)[0])
6769
assertions.assertTrue(
6870
cos_sim > COSINE_THRESHOLD,
6971
msg=f"CustomKwargs Module TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
7072
)
7173

74+
# Save the module
75+
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
76+
torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwargs_inputs=kwargs)
7277
# Clean up model env
7378
torch._dynamo.reset()

0 commit comments

Comments (0)