
Commit d26028f

Fixed a small typo in the test file
1 parent 9c029a5 commit d26028f

4 files changed: +15 additions, -8 deletions


py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 3 additions & 3 deletions
@@ -622,9 +622,9 @@ def convert_module_to_trt_engine(
     if kwarg_inputs is None:
         kwarg_inputs = {}
     # Prepare torch_trt inputs
-    arg_inputs = prepare_inputs(arg_inputs)
+    arg_input_list = list(prepare_inputs(arg_inputs))
     kwarg_input_list = prepare_inputs(kwarg_inputs)
-    flattened_input_list = arg_inputs + flatten_dict_value(kwarg_input_list)
+    flattened_input_list = arg_input_list + flatten_dict_value(kwarg_input_list)
     device = to_torch_tensorrt_device(device)
     enabled_precisions = {dtype._from(e) for e in enabled_precisions}

@@ -678,7 +678,7 @@ def convert_module_to_trt_engine(
     interpreter_result = interpret_module_to_result(
         gm,
         inputs=flattened_input_list,
-        arg_inputs=arg_inputs,
+        arg_inputs=arg_input_list,
         kwarg_inputs=kwarg_input_list,
         settings=settings,
     )
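For context, the rename ensures the prepared positional inputs are a plain Python list before they are concatenated with the flattened keyword inputs. A minimal sketch of the flattening idea, where flatten_kwarg_values is a hypothetical stand-in for the library's flatten_dict_value helper:

from typing import Any, Dict, List


def flatten_kwarg_values(kwarg_inputs: Dict[str, Any]) -> List[Any]:
    # Stand-in for the flattening step: collect leaf values from a
    # (possibly nested) kwarg dict into a flat list, in insertion order.
    flat: List[Any] = []
    for value in kwarg_inputs.values():
        if isinstance(value, dict):
            flat.extend(flatten_kwarg_values(value))
        elif isinstance(value, (list, tuple)):
            flat.extend(value)
        else:
            flat.append(value)
    return flat


# Mirroring the patched code path (names taken from the diff above):
#   arg_input_list = list(prepare_inputs(arg_inputs))   # guaranteed to be a list
#   flattened_input_list = arg_input_list + flatten_kwarg_values(kwarg_input_list)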

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 2 additions & 1 deletion
@@ -157,7 +157,8 @@ def refit_module_weights(
             This compiled_module should be compmiled by torch_tensorrt.dynamo.compile
             or load it from disk using trt.load.
         new_weight_module: exported program with the updated weights. This one should have the same model architecture as the compiled module.
-        inputs: sample inputs
+        arg_inputs: sample arg inputs. Optional, needed if output check
+        kwarg_inputs: sample kwarg inputs. Optional, needed if output check
         verify_output: whether to verify output of refitted module
     Returns:
         A new compiled TensorRT module that has the updated weights.
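For orientation, a hedged usage sketch of the refit path built only from the parameter names documented above. The tiny model, the input shapes, and any compile setting needed to make the engine refittable are assumptions, and parameter availability depends on the torch_tensorrt version:

import torch
import torch_tensorrt


class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x, scale=None):
        out = self.linear(x)
        return out * scale if scale is not None else out


args = (torch.randn(2, 8).cuda(),)
kwargs = {"scale": torch.tensor(2.0).cuda()}

# Compile one exported program, then refit it from a second exported program
# that has updated weights but the same architecture.
exp_v1 = torch.export.export(TinyModel().eval().cuda(), args, kwargs=kwargs)
trt_module = torch_tensorrt.dynamo.compile(
    exp_v1, arg_inputs=args, kwarg_inputs=kwargs
)  # NOTE: your torch_tensorrt version may also require a refit-enabling setting here

exp_v2 = torch.export.export(TinyModel().eval().cuda(), args, kwargs=kwargs)
refitted = torch_tensorrt.dynamo.refit_module_weights(
    compiled_module=trt_module,
    new_weight_module=exp_v2,
    arg_inputs=args,      # sample positional inputs, used when verifying outputs
    kwarg_inputs=kwargs,  # sample keyword inputs, used when verifying outputs
    verify_output=True,
)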

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 9 additions & 3 deletions
@@ -4,6 +4,7 @@
 import logging
 from typing import Any, List, Optional, Sequence

+import tensorrt as trt
 import torch
 from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
 from torch_tensorrt._Device import Device

@@ -18,8 +19,6 @@
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
 from torch_tensorrt.dynamo.utils import get_torch_inputs

-import tensorrt as trt
-
 logger = logging.getLogger(__name__)


@@ -30,6 +29,10 @@ def infer_module_output_dtypes(
     kwarg_inputs: Optional[dict[str, Any]] = None,
     truncate_double: bool = False,
 ) -> List[dtype]:
+    """
+    inputs can be either arg_inputs or flattened input list. If it is flattened list, kwarg_inputs
+    should be None, as it is already included in the flattened input.
+    """
     with maybe_disable_fake_tensor_mode():
         torch_inputs = get_torch_inputs(inputs, device)
         if kwarg_inputs is None:

@@ -72,7 +75,10 @@ def interpret_module_to_result(
     """Interpret an FX module to a TRTInterpreterResult
     Args:
         module: FX GraphModule to interpret
-        inputs: Sequence of Tensors representing inputs to the module
+        inputs: Sequence of FLATTENED Tensors representing inputs to the module. It should include both
+            arg_inputs and kwarg_inputs, if applicable.
+        arg_inputs: Sequence of Tensors representing inputs to the module.
+        kwarg_inputs: A dictionary of Tensors representing inputs to the module.
         settings: Compilation settings
     Returns:
         TRTInterpreterResult
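The docstring additions describe a calling contract rather than new behavior: callers may pass either positional inputs plus a kwarg dict, or a single already-flattened list with kwarg_inputs left as None. A minimal sketch of that contract, using a hypothetical normalize_inputs helper that is not part of the library:

from typing import Any, Dict, List, Optional, Sequence


def normalize_inputs(
    inputs: Sequence[Any],
    kwarg_inputs: Optional[Dict[str, Any]] = None,
) -> List[Any]:
    # If kwarg_inputs is None, `inputs` is treated as already flattened
    # (or there simply are no keyword inputs) and is returned as-is.
    if kwarg_inputs is None:
        return list(inputs)
    # Otherwise, append the kwarg values (recursing into nested containers)
    # after the positional inputs to build the flattened list.
    flat: List[Any] = list(inputs)
    for value in kwarg_inputs.values():
        if isinstance(value, dict):
            flat.extend(normalize_inputs([], value))
        elif isinstance(value, (list, tuple)):
            flat.extend(value)
        else:
            flat.append(value)
    return flat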

tests/py/dynamo/models/test_export_kwargs_serde.py

Lines changed: 1 addition & 1 deletion
@@ -453,7 +453,7 @@ def forward(self, x, b=None, c=None, d=None, e=[]):
         msg=f"CustomKwargs Module TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
     # Change the input shape
-    kwargs["d"][1] = torch.randn([10, 2])
+    kwargs["e"][1] = torch.randn([10, 2]).to("cuda")
     cos_sim = cosine_similarity(model(*args, **kwargs), trt_gm(*args, **kwargs)[0])
     assertions.assertTrue(
         cos_sim > COSINE_THRESHOLD,
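For reference, the test verifies the fix by mutating a nested keyword input, rerunning both the original model and the TRT module, and comparing their outputs with a cosine-similarity score. A self-contained sketch of that check, where cosine_similarity and COSINE_THRESHOLD are hypothetical stand-ins for the test utilities of the same names:

import torch
import torch.nn.functional as F

COSINE_THRESHOLD = 0.99  # assumed threshold; the real test imports its own value


def cosine_similarity(gt_tensor: torch.Tensor, pred_tensor: torch.Tensor) -> float:
    # Flatten both outputs and compare their directions; 1.0 means identical direction.
    return F.cosine_similarity(gt_tensor.flatten(), pred_tensor.flatten(), dim=0).item()


# Usage mirroring the test:
#   kwargs["e"][1] = torch.randn([10, 2]).to("cuda")
#   cos_sim = cosine_similarity(model(*args, **kwargs), trt_gm(*args, **kwargs)[0])
#   assert cos_sim > COSINE_THRESHOLD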
