
Commit c1eb9c3

cherry pick of #2832
1 parent 25a04ae commit c1eb9c3

2 files changed: +10 −8 lines changed


py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Lines changed: 6 additions & 4 deletions

@@ -317,8 +317,10 @@ def run(
         )
         timing_cache = self._create_timing_cache(builder_config, existing_cache)

-        engine = self.builder.build_serialized_network(self.ctx.net, builder_config)
-        assert engine
+        serialized_engine = self.builder.build_serialized_network(
+            self.ctx.net, builder_config
+        )
+        assert serialized_engine

         serialized_cache = (
             bytearray(timing_cache.serialize())
@@ -328,10 +330,10 @@ def run(
         _LOGGER.info(
             f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
         )
-        _LOGGER.info(f"TRT Engine uses: {engine.nbytes} bytes of Memory")
+        _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")

         return TRTInterpreterResult(
-            engine, self._input_names, self._output_names, serialized_cache
+            serialized_engine, self._input_names, self._output_names, serialized_cache
         )

     def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
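Context for this change: in the TensorRT Python API, builder.build_serialized_network() returns a serialized engine buffer (an IHostMemory, bytes-like), not a live ICudaEngine, so the interpreter now names the result serialized_engine and reports its size via .nbytes. Below is a minimal standalone sketch of that builder flow; the toy single-identity-layer network and the tensor name "x" are illustrative only and are not part of this commit.

import tensorrt as trt  # NVIDIA TensorRT Python bindings

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

# Toy network (illustrative): one input feeding an identity layer.
x = network.add_input("x", trt.float32, (1, 3, 224, 224))
identity = network.add_identity(x)
network.mark_output(identity.get_output(0))

config = builder.create_builder_config()

# build_serialized_network returns an IHostMemory buffer (or None on failure),
# which is why the interpreter asserts on the result and reads .nbytes for logging.
serialized_engine = builder.build_serialized_network(network, config)
assert serialized_engine is not None
print(f"Serialized engine size: {serialized_engine.nbytes} bytes")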

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
Lines changed: 4 additions & 4 deletions

@@ -30,7 +30,7 @@ class PythonTorchTensorRTModule(Module):  # type: ignore[misc]

     def __init__(
         self,
-        engine: trt.ICudaEngine,
+        engine: bytes,
         input_names: Optional[List[str]] = None,
         output_names: Optional[List[str]] = None,
         target_device: Device = Device._current_device(),
@@ -61,9 +61,9 @@ def _initialize(self) -> None:
         self.engine = runtime.deserialize_cuda_engine(self.engine)
         self.context = self.engine.create_execution_context()

-        assert (
-            self.engine.num_io_tensors // self.engine.num_optimization_profiles
-        ) == (len(self.input_names) + len(self.output_names))
+        assert self.engine.num_io_tensors == (
+            len(self.input_names) + len(self.output_names)
+        )

         self.input_dtypes = [
             dtype._from(self.engine.get_tensor_dtype(input_name))
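On the runtime side, the module now accepts the engine as raw bytes and deserializes it itself, and the I/O-count assertion no longer divides by the number of optimization profiles, since engine.num_io_tensors counts each named I/O tensor once regardless of how many profiles the engine carries. A minimal sketch of that consumer path follows, assuming TensorRT 8.5+ (the named-tensor API); the helper name load_engine and its arguments are hypothetical, not part of the module.

import tensorrt as trt  # NVIDIA TensorRT Python bindings


def load_engine(serialized_engine: bytes, input_names, output_names):
    # Hypothetical helper mirroring what _initialize() does: deserialize the
    # raw engine bytes and sanity-check the I/O tensor count.
    logger = trt.Logger(trt.Logger.WARNING)
    runtime = trt.Runtime(logger)

    # deserialize_cuda_engine accepts the serialized buffer produced by
    # build_serialized_network (bytes or IHostMemory).
    engine = runtime.deserialize_cuda_engine(serialized_engine)
    context = engine.create_execution_context()

    # num_io_tensors counts each named I/O tensor once, independent of the
    # number of optimization profiles.
    assert engine.num_io_tensors == len(input_names) + len(output_names)
    return engine, context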
