Skip to content

Commit 78ceda5

Browse files
zewenli98laikhtewari
authored andcommitted
fix: bugs in TRT 10 upgrade (#2832)
1 parent 5045ff8 commit 78ceda5

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
55

66
import numpy as np
7+
import tensorrt as trt
78
import torch
89
import torch.fx
910
from torch.fx.node import _get_qualified_name
@@ -25,7 +26,6 @@
2526
from torch_tensorrt.fx.observer import Observer
2627
from torch_tensorrt.logging import TRT_LOGGER
2728

28-
import tensorrt as trt
2929
from packaging import version
3030

3131
_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -316,8 +316,10 @@ def run(
316316
)
317317
timing_cache = self._create_timing_cache(builder_config, existing_cache)
318318

319-
engine = self.builder.build_serialized_network(self.ctx.net, builder_config)
320-
assert engine
319+
serialized_engine = self.builder.build_serialized_network(
320+
self.ctx.net, builder_config
321+
)
322+
assert serialized_engine
321323

322324
serialized_cache = (
323325
bytearray(timing_cache.serialize())
@@ -327,10 +329,10 @@ def run(
327329
_LOGGER.info(
328330
f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
329331
)
330-
_LOGGER.info(f"TRT Engine uses: {engine.nbytes} bytes of Memory")
332+
_LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")
331333

332334
return TRTInterpreterResult(
333-
engine, self._input_names, self._output_names, serialized_cache
335+
serialized_engine, self._input_names, self._output_names, serialized_cache
334336
)
335337

336338
def run_node(self, n: torch.fx.Node) -> torch.fx.Node:

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class PythonTorchTensorRTModule(Module): # type: ignore[misc]
2929

3030
def __init__(
3131
self,
32-
engine: trt.ICudaEngine,
32+
engine: bytes,
3333
input_names: Optional[List[str]] = None,
3434
output_names: Optional[List[str]] = None,
3535
target_device: Device = Device._current_device(),
@@ -60,9 +60,9 @@ def _initialize(self) -> None:
6060
self.engine = runtime.deserialize_cuda_engine(self.engine)
6161
self.context = self.engine.create_execution_context()
6262

63-
assert (
64-
self.engine.num_io_tensors // self.engine.num_optimization_profiles
65-
) == (len(self.input_names) + len(self.output_names))
63+
assert self.engine.num_io_tensors == (
64+
len(self.input_names) + len(self.output_names)
65+
)
6666

6767
self.input_dtypes = [
6868
dtype._from(self.engine.get_tensor_dtype(input_name))

0 commit comments

Comments
 (0)