4
4
from typing import Any , Callable , Dict , List , NamedTuple , Optional , Sequence , Set
5
5
6
6
import numpy as np
7
+ import tensorrt as trt
7
8
import torch
8
9
import torch .fx
9
10
from torch .fx .node import _get_qualified_name
25
26
from torch_tensorrt .fx .observer import Observer
26
27
from torch_tensorrt .logging import TRT_LOGGER
27
28
28
- import tensorrt as trt
29
29
from packaging import version
30
30
31
31
_LOGGER : logging .Logger = logging .getLogger (__name__ )
@@ -316,8 +316,10 @@ def run(
316
316
)
317
317
timing_cache = self ._create_timing_cache (builder_config , existing_cache )
318
318
319
- engine = self .builder .build_serialized_network (self .ctx .net , builder_config )
320
- assert engine
319
+ serialized_engine = self .builder .build_serialized_network (
320
+ self .ctx .net , builder_config
321
+ )
322
+ assert serialized_engine
321
323
322
324
serialized_cache = (
323
325
bytearray (timing_cache .serialize ())
@@ -327,10 +329,10 @@ def run(
327
329
_LOGGER .info (
328
330
f"Build TRT engine elapsed time: { datetime .now () - build_engine_start_time } "
329
331
)
330
- _LOGGER .info (f"TRT Engine uses: { engine .nbytes } bytes of Memory" )
332
+ _LOGGER .info (f"TRT Engine uses: { serialized_engine .nbytes } bytes of Memory" )
331
333
332
334
return TRTInterpreterResult (
333
- engine , self ._input_names , self ._output_names , serialized_cache
335
+ serialized_engine , self ._input_names , self ._output_names , serialized_cache
334
336
)
335
337
336
338
def run_node (self , n : torch .fx .Node ) -> torch .fx .Node :
0 commit comments