Skip to content

Commit 3bbeea3

Browse files
committed
Fixed a small typo in the test file
1 parent 9c029a5 commit 3bbeea3

File tree

3 files changed

+14
-8
lines changed

3 files changed

+14
-8
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -622,9 +622,9 @@ def convert_module_to_trt_engine(
622622
if kwarg_inputs is None:
623623
kwarg_inputs = {}
624624
# Prepare torch_trt inputs
625-
arg_inputs = prepare_inputs(arg_inputs)
626-
kwarg_input_list = prepare_inputs(kwarg_inputs)
627-
flattened_input_list = arg_inputs + flatten_dict_value(kwarg_input_list)
625+
arg_input_list = list(prepare_inputs(arg_inputs))
626+
kwarg_input_list = list(prepare_inputs(kwarg_inputs))
627+
flattened_input_list = arg_input_list + flatten_dict_value(kwarg_input_list)
628628
device = to_torch_tensorrt_device(device)
629629
enabled_precisions = {dtype._from(e) for e in enabled_precisions}
630630

@@ -678,7 +678,7 @@ def convert_module_to_trt_engine(
678678
interpreter_result = interpret_module_to_result(
679679
gm,
680680
inputs=flattened_input_list,
681-
arg_inputs=arg_inputs,
681+
arg_inputs=arg_input_list,
682682
kwarg_inputs=kwarg_input_list,
683683
settings=settings,
684684
)

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
from typing import Any, List, Optional, Sequence
66

7+
import tensorrt as trt
78
import torch
89
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
910
from torch_tensorrt._Device import Device
@@ -18,8 +19,6 @@
1819
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
1920
from torch_tensorrt.dynamo.utils import get_torch_inputs
2021

21-
import tensorrt as trt
22-
2322
logger = logging.getLogger(__name__)
2423

2524

@@ -30,6 +29,10 @@ def infer_module_output_dtypes(
3029
kwarg_inputs: Optional[dict[str, Any]] = None,
3130
truncate_double: bool = False,
3231
) -> List[dtype]:
32+
"""
33+
inputs can be either arg_inputs or flattened input list. If it is flattened list, kwarg_inputs
34+
should be None, as it is already included in the flattened input.
35+
"""
3336
with maybe_disable_fake_tensor_mode():
3437
torch_inputs = get_torch_inputs(inputs, device)
3538
if kwarg_inputs is None:
@@ -72,7 +75,10 @@ def interpret_module_to_result(
7275
"""Interpret an FX module to a TRTInterpreterResult
7376
Args:
7477
module: FX GraphModule to interpret
75-
inputs: Sequence of Tensors representing inputs to the module
78+
inputs: Sequence of FLATTENED Tensors representing inputs to the module. It should include both
79+
arg_inputs and kwarg_inputs, if applicable.
80+
arg_inputs: Sequence of Tensors representing inputs to the module.
81+
kwarg_inputs: A dictionary of Tensors representing inputs to the module.
7682
settings: Compilation settings
7783
Returns:
7884
TRTInterpreterResult

tests/py/dynamo/models/test_export_kwargs_serde.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def forward(self, x, b=None, c=None, d=None, e=[]):
453453
msg=f"CustomKwargs Module TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
454454
)
455455
# Change the input shape
456-
kwargs["d"][1] = torch.randn([10, 2])
456+
kwargs["e"][1] = torch.randn([10, 2]).to("cuda")
457457
cos_sim = cosine_similarity(model(*args, **kwargs), trt_gm(*args, **kwargs)[0])
458458
assertions.assertTrue(
459459
cos_sim > COSINE_THRESHOLD,

0 commit comments

Comments
 (0)