Skip to content

Commit cca4374

Browse files
committed
Added kwarg support for convert_module_to_engine
1 parent bac3099 commit cca4374

File tree

6 files changed

+58
-19
lines changed

6 files changed

+58
-19
lines changed

py/torch_tensorrt/_compile.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ def convert_method_to_trt_engine(
351351
torchtrt_inputs = prepare_inputs(inputs)
352352
exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
353353

354-
return dynamo_convert_module_to_trt_engine( # type: ignore
354+
return dynamo_convert_module_to_trt_engine(
355355
exp_program,
356356
inputs=tuple(inputs),
357357
enabled_precisions=enabled_precisions_set,
@@ -429,9 +429,10 @@ def save(
429429
raise ValueError(
430430
"Not all inputs provided are torch.tensors. Please provide torch.tensors as inputs"
431431
)
432-
if kwargs_inputs is not None and not all(
433-
value is not None for value in kwargs_inputs.values()
434-
):
432+
if kwargs_inputs is None:
433+
kwargs_inputs = {}
434+
435+
if kwargs_inputs and not all(value is not None for value in kwargs_inputs.values()):
435436
raise ValueError("kwargs should not include None.")
436437
if output_format not in accepted_formats:
437438
raise ValueError(

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
pre_export_lowering,
3535
)
3636
from torch_tensorrt.dynamo.utils import (
37+
flatten_dict_value,
3738
get_torch_inputs,
3839
parse_complex_tensor_structs,
3940
prepare_inputs,
@@ -495,6 +496,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
495496
def convert_module_to_trt_engine(
496497
exported_program: ExportedProgram,
497498
inputs: Sequence[Any],
499+
kwarg_inputs: Optional[dict[str, Any]] = None,
498500
*,
499501
enabled_precisions: (
500502
Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
@@ -608,12 +610,15 @@ def convert_module_to_trt_engine(
608610
stacklevel=2,
609611
)
610612

611-
input_list = list(inputs) if inputs is not None else []
613+
arg_input_list = list(inputs) if inputs is not None else []
612614
torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set()
615+
if kwarg_inputs is None:
616+
kwarg_inputs = {}
613617
# Prepare torch_trt inputs
614-
input_list = prepare_inputs(input_list)
618+
arg_input_list = prepare_inputs(arg_input_list)
619+
kwarg_input_list = prepare_inputs(kwarg_inputs)
620+
flattened_input_list = arg_input_list + flatten_dict_value(kwarg_input_list)
615621
device = to_torch_tensorrt_device(device)
616-
torch_inputs = get_torch_inputs(input_list, device)
617622
enabled_precisions = {dtype._from(e) for e in enabled_precisions}
618623

619624
compilation_options = {
@@ -662,8 +667,15 @@ def convert_module_to_trt_engine(
662667
# Assume converters support dynamic shapes and disable validation
663668
CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support)
664669

670+
interpreter_result = interpret_module_to_result(
671+
gm,
672+
inputs=flattened_input_list,
673+
arg_inputs=arg_input_list,
674+
kwarg_inputs=kwarg_input_list,
675+
settings=settings,
676+
)
665677
try:
666-
interpreter_result = interpret_module_to_result(gm, input_list, settings)
678+
pass
667679
except UnsupportedOperatorException:
668680
logger.error(
669681
f"Conversion of module {gm} not currently fully supported or convertible!",

py/torch_tensorrt/dynamo/_exporter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ def export(
3030
gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
3131
inputs (torch.Tensor): Torch input tensors
3232
"""
33+
if kwargs_inputs is None:
34+
kwargs_inputs = {}
3335
patched_module = transform(gm, inputs, kwargs_inputs)
3436
exp_program = create_trt_exp_program(patched_module)
3537
return exp_program
@@ -53,8 +55,9 @@ def transform(
5355
"""
5456
# Make a copy the graph since this function transforms the input graph and changes it's attributes.
5557
# This transformed graph is meant to be consumed by `create_trt_exp_program`
58+
if kwargs_inputs is None:
59+
kwargs_inputs = {}
5660
gm = copy.deepcopy(gm)
57-
5861
# Run shape analysis
5962
_, outputs_map = partitioning.run_shape_analysis(gm, inputs, kwargs_inputs)
6063

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import io
44
import logging
5-
from typing import List, Sequence
5+
from typing import Any, List, Optional, Sequence
66

77
import torch
88
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
@@ -27,12 +27,16 @@ def infer_module_output_dtypes(
2727
module: torch.fx.GraphModule,
2828
inputs: Sequence[Input],
2929
device: Device,
30+
kwarg_inputs: Optional[dict[str, Any]] = None,
3031
truncate_double: bool = False,
3132
) -> List[dtype]:
3233
with maybe_disable_fake_tensor_mode():
3334
torch_inputs = get_torch_inputs(inputs, device)
35+
if kwarg_inputs is None:
36+
kwarg_inputs = {}
37+
torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
3438
module = module.to(device.to(torch.device))
35-
module_outputs = module(*torch_inputs)
39+
module_outputs = module(*torch_inputs, **torch_kwarg_inputs)
3640
if not isinstance(module_outputs, (list, tuple)):
3741
module_outputs = [module_outputs]
3842

@@ -62,6 +66,8 @@ def interpret_module_to_result(
6266
module: torch.fx.GraphModule,
6367
inputs: Sequence[Input],
6468
settings: CompilationSettings = CompilationSettings(),
69+
arg_inputs: Optional[Sequence[Input]] = None,
70+
kwarg_inputs: Optional[dict[str, Any]] = None,
6571
) -> TRTInterpreterResult:
6672
"""Interpret an FX module to a TRTInterpreterResult
6773
Args:
@@ -71,12 +77,22 @@ def interpret_module_to_result(
7177
Returns:
7278
TRTInterpreterResult
7379
"""
74-
output_dtypes = infer_module_output_dtypes(
75-
module,
76-
inputs,
77-
settings.device,
78-
truncate_double=settings.truncate_double,
79-
)
80+
if arg_inputs is not None:
81+
output_dtypes = infer_module_output_dtypes(
82+
module,
83+
arg_inputs,
84+
settings.device,
85+
kwarg_inputs=kwarg_inputs,
86+
truncate_double=settings.truncate_double,
87+
)
88+
else:
89+
# args and kwargs are combined and flattened to one list
90+
output_dtypes = infer_module_output_dtypes(
91+
module,
92+
inputs,
93+
settings.device,
94+
truncate_double=settings.truncate_double,
95+
)
8096

8197
interpreter = TRTInterpreter(
8298
module,

py/torch_tensorrt/dynamo/partitioning/common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ def get_submodule_io(
147147
sub_outputs = outputs
148148
return
149149

150+
if kwargs_inputs is None:
151+
kwargs_inputs = {}
150152
# Iterate through submodules (both Torch and TRT) and store IO shapes
151153
for name, _ in parent_module.named_children():
152154
submodule = getattr(parent_module, name)

tests/py/dynamo/models/test_models_export_kwargs.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# type: ignore
2+
import os
3+
import tempfile
24
import unittest
35

46
import pytest
@@ -62,12 +64,15 @@ def forward(self, x, b=5, c=None, d=None):
6264
# trt_mod = torchtrt.compile(model, **compile_spec)
6365

6466
exp_program = torch.export.export(model, args=tuple(args), kwargs=kwargs)
65-
trt_mod = torchtrt.dynamo.compile(exp_program, **compile_spec)
66-
cos_sim = cosine_similarity(model(*args, **kwargs), trt_mod(*args, **kwargs)[0])
67+
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
68+
cos_sim = cosine_similarity(model(*args, **kwargs), trt_gm(*args, **kwargs)[0])
6769
assertions.assertTrue(
6870
cos_sim > COSINE_THRESHOLD,
6971
msg=f"CustomKwargs Module TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
7072
)
7173

74+
# Save the module
75+
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
76+
torchtrt.save(trt_gm, trt_ep_path, inputs=args, kwargs_inputs=kwargs)
7277
# Clean up model env
7378
torch._dynamo.reset()

0 commit comments

Comments
 (0)