
Commit 62673bd

fix bugs
1 parent cd61e54 commit 62673bd

2 files changed: +11 -9 lines changed

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 7 additions & 5 deletions
@@ -4,6 +4,7 @@
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
 
 import numpy as np
+import tensorrt as trt
 import torch
 import torch.fx
 from torch.fx.node import _get_qualified_name
@@ -25,7 +26,6 @@
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER
 
-import tensorrt as trt
 from packaging import version
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -313,8 +313,10 @@ def run(
         )
         timing_cache = self._create_timing_cache(builder_config, existing_cache)
 
-        engine = self.builder.build_serialized_network(self.ctx.net, builder_config)
-        assert engine
+        serialized_engine = self.builder.build_serialized_network(
+            self.ctx.net, builder_config
+        )
+        assert serialized_engine
 
         serialized_cache = (
             bytearray(timing_cache.serialize())
@@ -324,10 +326,10 @@ def run(
         _LOGGER.info(
             f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
         )
-        _LOGGER.info(f"TRT Engine uses: {engine.nbytes} bytes of Memory")
+        _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")
 
         return TRTInterpreterResult(
-            engine, self._input_names, self._output_names, serialized_cache
+            serialized_engine, self._input_names, self._output_names, serialized_cache
        )
 
     def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
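
For context, a minimal sketch (not the repository's code) of the builder flow this hunk switches to, assuming TensorRT 8.x+ where build_serialized_network returns a trt.IHostMemory plan rather than a live ICudaEngine; that is why the result is renamed to serialized_engine and its size is read from .nbytes. The one-layer identity network below is purely illustrative.

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

# Hypothetical one-layer network: an identity op over a fixed-shape input.
x = network.add_input("x", trt.float32, (1, 3, 224, 224))
identity = network.add_identity(x)
network.mark_output(identity.get_output(0))

config = builder.create_builder_config()

# build_serialized_network returns a serialized plan (IHostMemory), not an
# ICudaEngine, so the size is reported via .nbytes on the blob itself.
serialized_engine = builder.build_serialized_network(network, config)
assert serialized_engine is not None
print(f"TRT engine plan size: {serialized_engine.nbytes} bytes")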

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 4 additions & 4 deletions
@@ -29,7 +29,7 @@ class PythonTorchTensorRTModule(Module): # type: ignore[misc]
 
     def __init__(
         self,
-        engine: trt.ICudaEngine,
+        engine: trt.tensorrt.IHostMemory,
         input_names: Optional[List[str]] = None,
         output_names: Optional[List[str]] = None,
         target_device: Device = Device._current_device(),
@@ -60,9 +60,9 @@ def _initialize(self) -> None:
         self.engine = runtime.deserialize_cuda_engine(self.engine)
         self.context = self.engine.create_execution_context()
 
-        assert (
-            self.engine.num_io_tensors // self.engine.num_optimization_profiles
-        ) == (len(self.input_names) + len(self.output_names))
+        assert self.engine.num_io_tensors == (
+            len(self.input_names) + len(self.output_names)
+        )
 
         self.input_dtypes = [
             dtype._from(self.engine.get_tensor_dtype(input_name))
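
A companion sketch (again illustrative, not the module itself) of the consumer side: deserializing the serialized plan the runtime now receives and applying the revised sanity check. num_io_tensors counts each named I/O tensor once per engine, unlike the legacy binding count, so the old division by num_optimization_profiles is no longer needed. The helper name and signature below are assumptions for the example.

from typing import List

import tensorrt as trt


def deserialize_and_check(
    serialized_engine: trt.IHostMemory,
    input_names: List[str],
    output_names: List[str],
) -> trt.IExecutionContext:
    # Turn the serialized plan back into a live engine and execution context.
    runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
    engine = runtime.deserialize_cuda_engine(serialized_engine)
    context = engine.create_execution_context()

    # Each I/O tensor is counted once in num_io_tensors, regardless of how
    # many optimization profiles the engine was built with.
    assert engine.num_io_tensors == (len(input_names) + len(output_names))
    return context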
