Skip to content

Commit 366cd31

Browse files
committed
fix: Address review comments
- Ensure imports of utilities reference new directory structure
- Update the test cases to reflect the changes to `prepare_inputs.py`
- Add non-breaking functionality to `_Input` class
- Rename Python TRT runtime
- Reword runtime detection log message
1 parent f3c7fc7 commit 366cd31

File tree

8 files changed

+62
-36
lines changed

8 files changed

+62
-36
lines changed

py/torch_tensorrt/_Input.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -302,47 +302,58 @@ def _parse_tensor_domain(domain: Optional[Tuple[float, float]]) -> Tuple:
302302
return result_domain
303303

304304
@classmethod
305-
def from_tensor(cls, t: torch.Tensor) -> "Input":
305+
def from_tensor(
306+
cls, t: torch.Tensor, disable_memory_format_check: bool = False
307+
) -> "Input":
306308
"""
307309
Produce an Input which contains the information of the given PyTorch tensor.
308310
309311
Args:
310312
t (torch.Tensor): A PyTorch tensor.
313+
disable_memory_format_check (bool): If True, skip validation of the tensor's memory format and treat the tensor as contiguous format
311314
312315
Returns:
313316
An Input object.
314317
"""
315-
if not any(
316-
[
317-
t.is_contiguous(memory_format=torch.contiguous_format),
318-
t.is_contiguous(memory_format=torch.channels_last),
319-
]
318+
if not (
319+
t.is_contiguous(memory_format=torch.contiguous_format)
320+
or t.is_contiguous(memory_format=torch.channels_last)
321+
or disable_memory_format_check
320322
):
321323
raise ValueError(
322324
"Tensor does not have a supported memory format, supported formats are contiguous or channel_last"
323325
)
324326
frmt = (
325327
torch.contiguous_format
326-
if t.is_contiguous(memory_format=torch.contiguous_format)
328+
if (
329+
t.is_contiguous(memory_format=torch.contiguous_format)
330+
or disable_memory_format_check
331+
)
327332
else torch.channels_last
328333
)
329334
return cls(shape=t.shape, dtype=t.dtype, format=frmt)
330335

331336
@classmethod
332-
def from_tensors(cls, ts: torch.Tensor) -> List["Input"]:
337+
def from_tensors(
338+
cls, ts: torch.Tensor, disable_memory_format_check: bool = False
339+
) -> List["Input"]:
333340
"""
334341
Produce a list of Inputs which contain
335342
the information of all the given PyTorch tensors.
336343
337344
Args:
338345
ts (Iterable[torch.Tensor]): A list or tuple of PyTorch tensors.
346+
disable_memory_format_check (bool): If True, skip validation of the memory formats of the input tensors and treat them as contiguous format
339347
340348
Returns:
341349
A list of Inputs.
342350
"""
343351

344352
assert isinstance(ts, (list, tuple))
345-
return [cls.from_tensor(t) for t in ts]
353+
return [
354+
cls.from_tensor(t, disable_memory_format_check=disable_memory_format_check)
355+
for t in ts
356+
]
346357

347358
def example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor:
348359
"""

py/torch_tensorrt/dynamo/conversion/conversion.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Sequence, Union
22
import torch
33
import io
4-
from torch_tensorrt.dynamo.runtime import TRTModule
4+
from torch_tensorrt.dynamo.runtime import _PythonTorchTRTModule
55
from torch_tensorrt.dynamo import CompilationSettings
66
from torch_tensorrt import Input
77
from torch_tensorrt.dynamo.conversion import TRTInterpreter
@@ -23,7 +23,7 @@ def convert_module(
2323
settings: Compilation settings
2424
name: TRT engine name
2525
Returns:
26-
TRTModule or TRTModuleNext
26+
_PythonTorchTRTModule or TorchTensorRTModule
2727
"""
2828
# Specify module output data types to ensure TRT output types agree with
2929
# that of the equivalent Torch module
@@ -35,7 +35,7 @@ def convert_module(
3535
output_dtypes = list(output.dtype for output in module_outputs)
3636
interpreter = TRTInterpreter(
3737
module,
38-
Input.from_tensors(inputs),
38+
Input.from_tensors(inputs, disable_memory_format_check=True),
3939
logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
4040
output_dtypes=output_dtypes,
4141
)
@@ -53,7 +53,7 @@ def convert_module(
5353
)
5454

5555
if settings.use_python_runtime:
56-
return TRTModule(
56+
return _PythonTorchTRTModule(
5757
engine=interpreter_result.engine,
5858
input_names=interpreter_result.input_names,
5959
output_names=interpreter_result.output_names,

py/torch_tensorrt/dynamo/runtime/_PythonTorchTRTModule.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from torch_tensorrt.fx.utils import unified_dtype_converter, Frameworks
77

88

9-
class TRTModule(torch.nn.Module):
10-
"""TRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine.
9+
class PythonTorchTRTModule(torch.nn.Module):
10+
"""PythonTorchTRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine.
1111
1212
This module is backed by the Torch-TensorRT runtime and is only compatible with
1313
FX / Dynamo / Python deployments. This module cannot be serialized to torchscript via torch.jit.trace for C++ deployment.
@@ -16,8 +16,8 @@ class TRTModule(torch.nn.Module):
1616
def __init__(
1717
self, engine=None, input_names=None, output_names=None, cuda_graph_batch_size=-1
1818
):
19-
super(TRTModule, self).__init__()
20-
self._register_state_dict_hook(TRTModule._on_state_dict)
19+
super(PythonTorchTRTModule, self).__init__()
20+
self._register_state_dict_hook(PythonTorchTRTModule._on_state_dict)
2121
self.engine = engine
2222
self.input_names = input_names
2323
self.output_names = output_names
@@ -94,7 +94,7 @@ def _initialize(self):
9494

9595
def _check_initialized(self):
9696
if not self.initialized:
97-
raise RuntimeError("TRTModule is not initialized.")
97+
raise RuntimeError("PythonTorchTRTModule is not initialized.")
9898

9999
def _on_state_dict(self, state_dict, prefix, local_metadata):
100100
self._check_initialized()
@@ -138,10 +138,12 @@ def __setstate__(self, state):
138138
self.context = self.engine.create_execution_context()
139139

140140
def forward(self, *inputs):
141-
with torch.autograd.profiler.record_function("TRTModule:Forward"):
141+
with torch.autograd.profiler.record_function("PythonTorchTRTModule:Forward"):
142142
self._check_initialized()
143143

144-
with torch.autograd.profiler.record_function("TRTModule:ProcessInputs"):
144+
with torch.autograd.profiler.record_function(
145+
"PythonTorchTRTModule:ProcessInputs"
146+
):
145147
assert len(inputs) == len(
146148
self.input_names
147149
), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."
@@ -176,7 +178,9 @@ def forward(self, *inputs):
176178
f"Expect {self.input_shapes[i]}, got {inputs[i].size()[1:]}."
177179
)
178180

179-
with torch.autograd.profiler.record_function("TRTModule:ProcessOutputs"):
181+
with torch.autograd.profiler.record_function(
182+
"PythonTorchTRTModule:ProcessOutputs"
183+
):
180184
# create output tensors
181185
outputs: List[torch.Tensor] = []
182186

@@ -207,7 +211,9 @@ def forward(self, *inputs):
207211
)
208212
bindings[idx] = output.data_ptr()
209213

210-
with torch.autograd.profiler.record_function("TRTModule:TensorRTRuntime"):
214+
with torch.autograd.profiler.record_function(
215+
"PythonTorchTRTModule:TensorRTRuntime"
216+
):
211217
if self.engine.has_implicit_batch_dimension:
212218
self.context.execute_async(
213219
batch_size, bindings, torch.cuda.current_stream().cuda_stream

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(
3636
output_binding_names: List[str] = [],
3737
target_device: Device = Device._current_device(),
3838
):
39-
"""__init__ method for torch_tensorrt.TorchTensorRTModule
39+
"""__init__ method for torch_tensorrt.dynamo.runtime._TorchTensorRTModule.TorchTensorRTModule
4040
4141
Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
4242
a PyTorch ``torch.nn.Module`` around it.
@@ -61,11 +61,11 @@ def __init__(
6161
engine_bytes.write(trt_engine.serialize())
6262
engine_str = engine_bytes.getvalue()
6363
64-
trt_module = TRTModule(
64+
trt_module = TorchTensorRTModule(
6565
engine_str,
66-
engine_name="my_module",
67-
input_names=["x"],
68-
output_names=["output"],
66+
name="my_module",
67+
input_binding_names=["x"],
68+
output_binding_names=["output"],
6969
)
7070
7171
"""
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from ._PythonTorchTRTModule import TRTModule
1+
from ._PythonTorchTRTModule import PythonTorchTRTModule
22
from ._TorchTensorRTModule import TorchTensorRTModule

py/torch_tensorrt/dynamo/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def use_python_runtime_parser(use_python_runtime: Optional[bool] = None) -> bool
3636
reason = "since import failed, C++ dependency not installed"
3737

3838
logger.info(
39-
f"Using {'Python' if using_python_runtime else 'C++'} {reason} TRT Runtime"
39+
f"Using {'Python-only' if using_python_runtime else 'Default'} Torch-TRT Runtime ({reason})"
4040
)
4141

4242
return using_python_runtime

tests/py/dynamo/backend/test_compiler_utils.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,22 @@ def test_prepare_trt_device(self):
2424
class TestPrepareInputs(unittest.TestCase):
2525
def test_prepare_single_tensor_input(self):
2626
inputs = [torch.ones((4, 4))]
27-
prepared_inputs = prepare_inputs(inputs)
27+
prepared_inputs_trt, prepared_inputs_torch = prepare_inputs(inputs)
2828
self.assertTrue(
29-
same_output_format(inputs, prepared_inputs, enforce_tensor_type=False)
29+
same_output_format(inputs, prepared_inputs_trt, enforce_tensor_type=False)
30+
)
31+
self.assertTrue(
32+
same_output_format(inputs, prepared_inputs_torch, enforce_tensor_type=False)
3033
)
3134

3235
def test_prepare_trt_input(self):
3336
inputs = [torch_tensorrt.Input(shape=(4, 3), dtype=torch.float)]
34-
prepared_inputs = prepare_inputs(inputs)
37+
prepared_inputs_trt, prepared_inputs_torch = prepare_inputs(inputs)
38+
self.assertTrue(
39+
same_output_format(inputs, prepared_inputs_trt, enforce_tensor_type=False)
40+
)
3541
self.assertTrue(
36-
same_output_format(inputs, prepared_inputs, enforce_tensor_type=False)
42+
same_output_format(inputs, prepared_inputs_torch, enforce_tensor_type=False)
3743
)
3844

3945
def test_prepare_mixed_type_compound_tensor_input(self):
@@ -47,9 +53,12 @@ def test_prepare_mixed_type_compound_tensor_input(self):
4753
(torch.rand((5, 1)), torch_tensorrt.Input(shape=(2, 3))),
4854
),
4955
}
50-
prepared_inputs = prepare_inputs(inputs)
56+
prepared_inputs_trt, prepared_inputs_torch = prepare_inputs(inputs)
57+
self.assertTrue(
58+
same_output_format(inputs, prepared_inputs_trt, enforce_tensor_type=False)
59+
)
5160
self.assertTrue(
52-
same_output_format(inputs, prepared_inputs, enforce_tensor_type=False)
61+
same_output_format(inputs, prepared_inputs_torch, enforce_tensor_type=False)
5362
)
5463

5564

tests/py/ts/api/test_classes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import unittest
22
import torch_tensorrt as torchtrt
3-
from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule
3+
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule
44
import torch
55
import torchvision.models as models
66
import copy

0 commit comments

Comments
 (0)