Skip to content

Commit fe2cd6a

Browse files
committed
feat: Lazy engine initialization
Allows engines to not be setup immediately after compilation but all at once before the module is returned back to the user. Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent c0a2bea commit fe2cd6a

File tree

11 files changed

+325
-87
lines changed

11 files changed

+325
-87
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def compile(
7979
dryrun: bool = _defaults.DRYRUN,
8080
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
8181
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
82+
lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
8283
**kwargs: Any,
8384
) -> torch.fx.GraphModule:
8485
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -141,6 +142,7 @@ def compile(
141142
dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
142143
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
143144
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
145+
lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
144146
**kwargs: Any,
145147
Returns:
146148
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -236,6 +238,7 @@ def compile(
236238
"dryrun": dryrun,
237239
"hardware_compatible": hardware_compatible,
238240
"timing_cache_path": timing_cache_path,
241+
"lazy_engine_init": lazy_engine_init,
239242
}
240243

241244
settings = CompilationSettings(**compilation_options)
@@ -454,6 +457,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
454457
# Replace all FX Modules with TRT Modules
455458
for name, trt_module in trt_modules.items():
456459
setattr(partitioned_module, name, trt_module)
460+
if settings.lazy_engine_init:
461+
getattr(partitioned_module, name).setup_engine()
457462

458463
# Reset settings object to user specification after fallback to global partitioning mode
459464
if fast_partitioner_failed:
@@ -647,10 +652,5 @@ def convert_module_to_trt_engine(
647652
exc_info=True,
648653
)
649654

650-
import io
651-
652-
with io.BytesIO() as engine_bytes:
653-
engine_bytes.write(interpreter_result.engine)
654-
engine_bytearray = engine_bytes.getvalue()
655-
656-
return engine_bytearray
655+
serialized_engine: bytes = interpreter_result.serialized_engine
656+
return serialized_engine

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
HARDWARE_COMPATIBLE = False
3333
SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8}
3434
TIMING_CACHE_PATH = os.path.join(tempfile.gettempdir(), "timing_cache.bin")
35+
LAZY_ENGINE_INIT = False
3536

3637

3738
def default_device() -> Device:

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
ENABLED_PRECISIONS,
1717
ENGINE_CAPABILITY,
1818
HARDWARE_COMPATIBLE,
19+
LAZY_ENGINE_INIT,
1920
MAKE_REFITABLE,
2021
MAX_AUX_STREAMS,
2122
MIN_BLOCK_SIZE,
@@ -104,3 +105,4 @@ class CompilationSettings:
104105
dryrun: Union[bool, str] = DRYRUN
105106
hardware_compatible: bool = HARDWARE_COMPATIBLE
106107
timing_cache_path: str = TIMING_CACHE_PATH
108+
lazy_engine_init: bool = LAZY_ENGINE_INIT

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
import io
12
import logging
23
import os
34
import warnings
45
from datetime import datetime
56
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple
67

78
import numpy as np
8-
import tensorrt as trt
99
import torch
1010
import torch.fx
1111
from torch.fx.node import _get_qualified_name
@@ -29,6 +29,7 @@
2929
from torch_tensorrt.fx.observer import Observer
3030
from torch_tensorrt.logging import TRT_LOGGER
3131

32+
import tensorrt as trt
3233
from packaging import version
3334

3435
_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -43,7 +44,7 @@ class UnsupportedOperatorException(RuntimeError):
4344

4445

4546
class TRTInterpreterResult(NamedTuple):
46-
engine: Any
47+
serialized_engine: bytes
4748
input_names: Sequence[str]
4849
output_names: Sequence[str]
4950

@@ -358,9 +359,11 @@ def run(
358359
builder_config, self.compilation_settings.timing_cache_path
359360
)
360361

361-
return TRTInterpreterResult(
362-
serialized_engine, self._input_names, self._output_names
363-
)
362+
with io.BytesIO() as engine_bytes:
363+
engine_bytes.write(serialized_engine)
364+
engine_str = engine_bytes.getvalue()
365+
366+
return TRTInterpreterResult(engine_str, self._input_names, self._output_names)
364367

365368
def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
366369
self._cur_node_name = get_node_name(n)

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import io
43
import logging
54
from typing import List, Sequence
65

@@ -102,33 +101,30 @@ def convert_module(
102101
settings: Compilation settings
103102
name: TRT engine name
104103
Returns:
105-
_PythonTorchTensorRTModule or TorchTensorRTModule
104+
PythonTorchTensorRTModule or TorchTensorRTModule
106105
"""
107106
interpreter_result = interpret_module_to_result(module, inputs, settings)
108107

109-
if settings.use_python_runtime or not ENABLED_FEATURES.torch_tensorrt_runtime:
110-
if not settings.use_python_runtime:
111-
logger.info(
112-
"Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available"
113-
)
114-
return PythonTorchTensorRTModule(
115-
engine=interpreter_result.engine,
116-
input_names=list(interpreter_result.input_names),
117-
output_names=list(interpreter_result.output_names),
118-
settings=settings,
119-
)
108+
rt_cls = PythonTorchTensorRTModule
109+
110+
if ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime:
120111

121-
else:
122112
from torch_tensorrt.dynamo.runtime import TorchTensorRTModule
123113

124-
with io.BytesIO() as engine_bytes:
125-
engine_bytes.write(interpreter_result.engine)
126-
engine_str = engine_bytes.getvalue()
114+
rt_cls = TorchTensorRTModule
115+
116+
elif (
117+
not ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime
118+
):
127119

128-
return TorchTensorRTModule(
129-
serialized_engine=engine_str,
130-
name=name,
131-
input_binding_names=list(interpreter_result.input_names),
132-
output_binding_names=list(interpreter_result.output_names),
133-
settings=settings,
120+
logger.info(
121+
"Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available"
134122
)
123+
124+
return rt_cls(
125+
serialized_engine=interpreter_result.serialized_engine,
126+
input_binding_names=list(interpreter_result.input_names),
127+
output_binding_names=list(interpreter_result.output_names),
128+
name=name,
129+
settings=settings,
130+
)

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from contextlib import nullcontext
55
from typing import Any, Dict, List, Optional, Sequence, Tuple
66

7-
import tensorrt as trt
87
import torch
8+
import torch_tensorrt
99
from torch.nn import Module
1010
from torch_tensorrt._Device import Device
1111
from torch_tensorrt._enums import dtype
@@ -18,7 +18,7 @@
1818
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
1919
from torch_tensorrt.logging import TRT_LOGGER
2020

21-
import torch_tensorrt
21+
import tensorrt as trt
2222

2323
logger = logging.getLogger(__name__)
2424

@@ -30,11 +30,15 @@ class PythonTorchTensorRTModule(Module): # type: ignore[misc]
3030
FX / Dynamo / Python deployments. This module cannot be serialized to torchscript via torch.jit.trace for C++ deployment.
3131
"""
3232

33+
defer_engine_setup = False
34+
3335
def __init__(
3436
self,
35-
engine: bytes,
36-
input_names: Optional[List[str]] = None,
37-
output_names: Optional[List[str]] = None,
37+
serialized_engine: bytes,
38+
input_binding_names: Optional[List[str]] = None,
39+
output_binding_names: Optional[List[str]] = None,
40+
*,
41+
name: str = "",
3842
settings: CompilationSettings = CompilationSettings(),
3943
):
4044
super(PythonTorchTensorRTModule, self).__init__()
@@ -43,6 +47,7 @@ def __init__(
4347
# Run multi-gpu device check to validate engine instantiation
4448
multi_gpu_device_check()
4549

50+
self.name = name
4651
self.input_buffers: List[torch.Tensor] = []
4752
self.output_buffers: List[torch.Tensor] = []
4853
self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
@@ -55,9 +60,13 @@ def __init__(
5560
# Unused currently - to be used by Dynamic Shape support implementation
5661
self.memory_pool = None
5762

58-
self.engine = engine
59-
self.input_names = input_names if input_names is not None else []
60-
self.output_names = output_names if output_names is not None else []
63+
self.serialized_engine = serialized_engine
64+
self.input_names = (
65+
input_binding_names if input_binding_names is not None else []
66+
)
67+
self.output_names = (
68+
output_binding_names if output_binding_names is not None else []
69+
)
6170
self.initialized = False
6271
self.target_device_id = (
6372
settings.device.gpu_id
@@ -69,12 +78,15 @@ def __init__(
6978
)
7079
self.profiling_enabled = settings.debug if settings.debug is not None else False
7180
self.settings = settings
72-
self._initialize()
81+
self.engine = None
82+
83+
if self.serialized_engine is not None and not self.settings.lazy_engine_init:
84+
self.setup_engine()
7385

74-
def _initialize(self) -> None:
86+
def setup_engine(self) -> None:
7587
self.initialized = True
7688
runtime = trt.Runtime(TRT_LOGGER)
77-
self.engine = runtime.deserialize_cuda_engine(self.engine)
89+
self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
7890
self.context = self.engine.create_execution_context()
7991

8092
assert self.engine.num_io_tensors == (
@@ -114,8 +126,7 @@ def _check_initialized(self) -> None:
114126
raise RuntimeError("PythonTorchTensorRTModule is not initialized.")
115127

116128
def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> None:
117-
self._check_initialized()
118-
state_dict[prefix + "engine"] = bytearray(self.engine.serialize())
129+
state_dict[prefix + "engine"] = bytearray(self.serialized_engine)
119130
state_dict[prefix + "input_names"] = self.input_names
120131
state_dict[prefix + "output_names"] = self.output_names
121132

@@ -129,17 +140,13 @@ def _load_from_state_dict(
129140
unexpected_keys: Any,
130141
error_msgs: Any,
131142
) -> None:
132-
engine_bytes = state_dict[prefix + "engine"]
143+
self.serialized_engine = state_dict[prefix + "engine"]
144+
self.input_names = state_dict[prefix + "input_names"]
145+
self.output_names = state_dict[prefix + "output_names"]
133146

134147
# Run multi-gpu device check to validate engine instantiation
135148
multi_gpu_device_check()
136-
137-
runtime = trt.Runtime(TRT_LOGGER)
138-
self.engine = runtime.deserialize_cuda_engine(engine_bytes)
139-
140-
self.input_names = state_dict[prefix + "input_names"]
141-
self.output_names = state_dict[prefix + "output_names"]
142-
self._initialize()
149+
self.setup_engine()
143150

144151
def __getstate__(self) -> Dict[str, Any]:
145152
state = self.__dict__.copy()

0 commit comments

Comments
 (0)