Skip to content

Commit cae4104

Browse files
committed
Added kwarg support for convert_module_to_trt_engine
1 parent 8518042 commit cae4104

File tree

6 files changed

+59
-21
lines changed

6 files changed

+59
-21
lines changed

py/torch_tensorrt/_compile.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ def convert_method_to_trt_engine(
351351
torchtrt_inputs = prepare_inputs(inputs)
352352
exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
353353

354-
return dynamo_convert_module_to_trt_engine( # type: ignore
354+
return dynamo_convert_module_to_trt_engine(
355355
exp_program,
356356
inputs=tuple(inputs),
357357
enabled_precisions=enabled_precisions_set,
@@ -429,9 +429,10 @@ def save(
429429
raise ValueError(
430430
"Not all inputs provided are torch.tensors. Please provide torch.tensors as inputs"
431431
)
432-
if kwargs_inputs is not None and not all(
433-
value is not None for value in kwargs_inputs.values()
434-
):
432+
if kwargs_inputs is None:
433+
kwargs_inputs = {}
434+
435+
if kwargs_inputs and not all(value is not None for value in kwargs_inputs.values()):
435436
raise ValueError("kwargs should not include None.")
436437
if output_format not in accepted_formats:
437438
raise ValueError(

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
pre_export_lowering,
3535
)
3636
from torch_tensorrt.dynamo.utils import (
37+
flatten_dict_value,
3738
get_torch_inputs,
3839
parse_complex_tensor_structs,
3940
prepare_inputs,
@@ -481,6 +482,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
481482
def convert_module_to_trt_engine(
482483
exported_program: ExportedProgram,
483484
inputs: Sequence[Any],
485+
kwarg_inputs: Optional[dict[str, Any]] = None,
484486
*,
485487
enabled_precisions: (
486488
Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
@@ -594,12 +596,15 @@ def convert_module_to_trt_engine(
594596
stacklevel=2,
595597
)
596598

597-
input_list = list(inputs) if inputs is not None else []
599+
arg_input_list = list(inputs) if inputs is not None else []
598600
torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set()
601+
if kwarg_inputs is None:
602+
kwarg_inputs = {}
599603
# Prepare torch_trt inputs
600-
input_list = prepare_inputs(input_list)
604+
arg_input_list = prepare_inputs(arg_input_list)
605+
kwarg_input_list = prepare_inputs(kwarg_inputs)
606+
flattened_input_list = arg_input_list + flatten_dict_value(kwarg_input_list)
601607
device = to_torch_tensorrt_device(device)
602-
torch_inputs = get_torch_inputs(input_list, device)
603608
enabled_precisions = {dtype._from(e) for e in enabled_precisions}
604609

605610
compilation_options = {
@@ -648,8 +653,15 @@ def convert_module_to_trt_engine(
648653
# Assume converters support dynamic shapes and disable validation
649654
CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support)
650655

656+
interpreter_result = interpret_module_to_result(
657+
gm,
658+
inputs=flattened_input_list,
659+
arg_inputs=arg_input_list,
660+
kwarg_inputs=kwarg_input_list,
661+
settings=settings,
662+
)
651663
try:
652-
interpreter_result = interpret_module_to_result(gm, input_list, settings)
664+
pass
653665
except UnsupportedOperatorException:
654666
logger.error(
655667
f"Conversion of module {gm} not currently fully supported or convertible!",

py/torch_tensorrt/dynamo/_exporter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ def export(
3030
gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
3131
inputs (torch.Tensor): Torch input tensors
3232
"""
33+
if kwargs_inputs is None:
34+
kwargs_inputs = {}
3335
patched_module = transform(gm, inputs, kwargs_inputs)
3436
exp_program = create_trt_exp_program(patched_module)
3537
return exp_program
@@ -53,8 +55,9 @@ def transform(
5355
"""
5456
# Make a copy of the graph since this function transforms the input graph and changes its attributes.
5557
# This transformed graph is meant to be consumed by `create_trt_exp_program`
58+
if kwargs_inputs is None:
59+
kwargs_inputs = {}
5660
gm = copy.deepcopy(gm)
57-
5861
# Run shape analysis
5962
_, outputs_map = partitioning.run_shape_analysis(gm, inputs, kwargs_inputs)
6063

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
import io
44
import logging
5-
from typing import List, Sequence
5+
from typing import Any, List, Optional, Sequence
66

7+
import tensorrt as trt
78
import torch
89
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
910
from torch_tensorrt._Device import Device
@@ -18,21 +19,23 @@
1819
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
1920
from torch_tensorrt.dynamo.utils import get_torch_inputs
2021

21-
import tensorrt as trt
22-
2322
logger = logging.getLogger(__name__)
2423

2524

2625
def infer_module_output_dtypes(
2726
module: torch.fx.GraphModule,
2827
inputs: Sequence[Input],
2928
device: Device,
29+
kwarg_inputs: Optional[dict[str, Any]] = None,
3030
truncate_double: bool = False,
3131
) -> List[dtype]:
3232
with maybe_disable_fake_tensor_mode():
3333
torch_inputs = get_torch_inputs(inputs, device)
34+
if kwarg_inputs is None:
35+
kwarg_inputs = {}
36+
torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
3437
module = module.to(device.to(torch.device))
35-
module_outputs = module(*torch_inputs)
38+
module_outputs = module(*torch_inputs, **torch_kwarg_inputs)
3639
if not isinstance(module_outputs, (list, tuple)):
3740
module_outputs = [module_outputs]
3841

@@ -62,6 +65,8 @@ def interpret_module_to_result(
6265
module: torch.fx.GraphModule,
6366
inputs: Sequence[Input],
6467
settings: CompilationSettings = CompilationSettings(),
68+
arg_inputs: Optional[Sequence[Input]] = None,
69+
kwarg_inputs: Optional[dict[str, Any]] = None,
6570
) -> TRTInterpreterResult:
6671
"""Interpret an FX module to a TRTInterpreterResult
6772
Args:
@@ -71,12 +76,22 @@ def interpret_module_to_result(
7176
Returns:
7277
TRTInterpreterResult
7378
"""
74-
output_dtypes = infer_module_output_dtypes(
75-
module,
76-
inputs,
77-
settings.device,
78-
truncate_double=settings.truncate_double,
79-
)
79+
if arg_inputs is not None:
80+
output_dtypes = infer_module_output_dtypes(
81+
module,
82+
arg_inputs,
83+
settings.device,
84+
kwarg_inputs=kwarg_inputs,
85+
truncate_double=settings.truncate_double,
86+
)
87+
else:
88+
# args and kwargs are combined and flattened to one list
89+
output_dtypes = infer_module_output_dtypes(
90+
module,
91+
inputs,
92+
settings.device,
93+
truncate_double=settings.truncate_double,
94+
)
8095

8196
interpreter = TRTInterpreter(
8297
module,

py/torch_tensorrt/dynamo/partitioning/common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ def get_submodule_io(
147147
sub_outputs = outputs
148148
return
149149

150+
if kwargs_inputs is None:
151+
kwargs_inputs = {}
150152
# Iterate through submodules (both Torch and TRT) and store IO shapes
151153
for name, _ in parent_module.named_children():
152154
submodule = getattr(parent_module, name)

tests/py/dynamo/models/test_models_export_kwargs.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# type: ignore
2+
import os
3+
import tempfile
24
import unittest
35

46
import pytest
@@ -62,12 +64,15 @@ def forward(self, x, b=5, c=None, d=None):
6264
# trt_mod = torchtrt.compile(model, **compile_spec)
6365

6466
exp_program = torch.export.export(model, args=tuple(args), kwargs=kwargs)
65-
trt_mod = torchtrt.dynamo.compile(exp_program, **compile_spec)
66-
cos_sim = cosine_similarity(model(*args, **kwargs), trt_mod(*args, **kwargs)[0])
67+
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
68+
cos_sim = cosine_similarity(model(*args, **kwargs), trt_gm(*args, **kwargs)[0])
6769
assertions.assertTrue(
6870
cos_sim > COSINE_THRESHOLD,
6971
msg=f"CustomKwargs Module TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
7072
)
7173

74+
# Save the module
75+
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
76+
torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwargs_inputs=kwargs)
7277
# Clean up model env
7378
torch._dynamo.reset()

0 commit comments

Comments
 (0)