4
4
from typing import Any , Callable , Dict , List , NamedTuple , Optional , Sequence , Set
5
5
6
6
import numpy as np
7
+ import tensorrt as trt
7
8
import torch
8
9
import torch .fx
9
10
from torch .fx .node import _get_qualified_name
25
26
from torch_tensorrt .fx .observer import Observer
26
27
from torch_tensorrt .logging import TRT_LOGGER
27
28
28
- import tensorrt as trt
29
29
from packaging import version
30
30
31
31
_LOGGER : logging .Logger = logging .getLogger (__name__ )
@@ -313,8 +313,10 @@ def run(
313
313
)
314
314
timing_cache = self ._create_timing_cache (builder_config , existing_cache )
315
315
316
- engine = self .builder .build_serialized_network (self .ctx .net , builder_config )
317
- assert engine
316
+ serialized_engine = self .builder .build_serialized_network (
317
+ self .ctx .net , builder_config
318
+ )
319
+ assert serialized_engine
318
320
319
321
serialized_cache = (
320
322
bytearray (timing_cache .serialize ())
@@ -324,10 +326,10 @@ def run(
324
326
_LOGGER .info (
325
327
f"Build TRT engine elapsed time: { datetime .now () - build_engine_start_time } "
326
328
)
327
- _LOGGER .info (f"TRT Engine uses: { engine .nbytes } bytes of Memory" )
329
+ _LOGGER .info (f"TRT Engine uses: { serialized_engine .nbytes } bytes of Memory" )
328
330
329
331
return TRTInterpreterResult (
330
- engine , self ._input_names , self ._output_names , serialized_cache
332
+ serialized_engine , self ._input_names , self ._output_names , serialized_cache
331
333
)
332
334
333
335
def run_node (self , n : torch .fx .Node ) -> torch .fx .Node :
0 commit comments