Skip to content

Commit 37adede

Browse files
committed
feat: Lazy engine initialization
Allows engine setup to be deferred: instead of each engine being set up immediately after it is compiled, all engines are set up at once just before the module is returned to the user. Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent c0a2bea commit 37adede

File tree

11 files changed

+527
-89
lines changed

11 files changed

+527
-89
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def compile(
7979
dryrun: bool = _defaults.DRYRUN,
8080
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
8181
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
82+
lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
8283
**kwargs: Any,
8384
) -> torch.fx.GraphModule:
8485
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -141,6 +142,7 @@ def compile(
141142
dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
142143
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
143144
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
145+
lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
144146
**kwargs: Any,
145147
Returns:
146148
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -236,6 +238,7 @@ def compile(
236238
"dryrun": dryrun,
237239
"hardware_compatible": hardware_compatible,
238240
"timing_cache_path": timing_cache_path,
241+
"lazy_engine_init": lazy_engine_init,
239242
}
240243

241244
settings = CompilationSettings(**compilation_options)
@@ -454,6 +457,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
454457
# Replace all FX Modules with TRT Modules
455458
for name, trt_module in trt_modules.items():
456459
setattr(partitioned_module, name, trt_module)
460+
if settings.lazy_engine_init:
461+
getattr(partitioned_module, name).setup_engine()
457462

458463
# Reset settings object to user specification after fallback to global partitioning mode
459464
if fast_partitioner_failed:
@@ -647,10 +652,5 @@ def convert_module_to_trt_engine(
647652
exc_info=True,
648653
)
649654

650-
import io
651-
652-
with io.BytesIO() as engine_bytes:
653-
engine_bytes.write(interpreter_result.engine)
654-
engine_bytearray = engine_bytes.getvalue()
655-
656-
return engine_bytearray
655+
serialized_engine: bytes = interpreter_result.serialized_engine
656+
return serialized_engine

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
HARDWARE_COMPATIBLE = False
3333
SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8}
3434
TIMING_CACHE_PATH = os.path.join(tempfile.gettempdir(), "timing_cache.bin")
35+
LAZY_ENGINE_INIT = False
3536

3637

3738
def default_device() -> Device:

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
ENABLED_PRECISIONS,
1717
ENGINE_CAPABILITY,
1818
HARDWARE_COMPATIBLE,
19+
LAZY_ENGINE_INIT,
1920
MAKE_REFITABLE,
2021
MAX_AUX_STREAMS,
2122
MIN_BLOCK_SIZE,
@@ -104,3 +105,4 @@ class CompilationSettings:
104105
dryrun: Union[bool, str] = DRYRUN
105106
hardware_compatible: bool = HARDWARE_COMPATIBLE
106107
timing_cache_path: str = TIMING_CACHE_PATH
108+
lazy_engine_init: bool = LAZY_ENGINE_INIT

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
import io
12
import logging
23
import os
34
import warnings
45
from datetime import datetime
56
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple
67

78
import numpy as np
8-
import tensorrt as trt
99
import torch
1010
import torch.fx
1111
from torch.fx.node import _get_qualified_name
@@ -29,6 +29,7 @@
2929
from torch_tensorrt.fx.observer import Observer
3030
from torch_tensorrt.logging import TRT_LOGGER
3131

32+
import tensorrt as trt
3233
from packaging import version
3334

3435
_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -43,7 +44,7 @@ class UnsupportedOperatorException(RuntimeError):
4344

4445

4546
class TRTInterpreterResult(NamedTuple):
46-
engine: Any
47+
serialized_engine: bytes
4748
input_names: Sequence[str]
4849
output_names: Sequence[str]
4950

@@ -358,9 +359,11 @@ def run(
358359
builder_config, self.compilation_settings.timing_cache_path
359360
)
360361

361-
return TRTInterpreterResult(
362-
serialized_engine, self._input_names, self._output_names
363-
)
362+
with io.BytesIO() as engine_bytes:
363+
engine_bytes.write(serialized_engine)
364+
engine_str = engine_bytes.getvalue()
365+
366+
return TRTInterpreterResult(engine_str, self._input_names, self._output_names)
364367

365368
def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
366369
self._cur_node_name = get_node_name(n)

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import io
43
import logging
54
from typing import List, Sequence
65

@@ -102,33 +101,30 @@ def convert_module(
102101
settings: Compilation settings
103102
name: TRT engine name
104103
Returns:
105-
_PythonTorchTensorRTModule or TorchTensorRTModule
104+
PythonTorchTensorRTModule or TorchTensorRTModule
106105
"""
107106
interpreter_result = interpret_module_to_result(module, inputs, settings)
108107

109-
if settings.use_python_runtime or not ENABLED_FEATURES.torch_tensorrt_runtime:
110-
if not settings.use_python_runtime:
111-
logger.info(
112-
"Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available"
113-
)
114-
return PythonTorchTensorRTModule(
115-
engine=interpreter_result.engine,
116-
input_names=list(interpreter_result.input_names),
117-
output_names=list(interpreter_result.output_names),
118-
settings=settings,
119-
)
108+
rt_cls = PythonTorchTensorRTModule
109+
110+
if ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime:
120111

121-
else:
122112
from torch_tensorrt.dynamo.runtime import TorchTensorRTModule
123113

124-
with io.BytesIO() as engine_bytes:
125-
engine_bytes.write(interpreter_result.engine)
126-
engine_str = engine_bytes.getvalue()
114+
rt_cls = TorchTensorRTModule
115+
116+
elif (
117+
not ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime
118+
):
127119

128-
return TorchTensorRTModule(
129-
serialized_engine=engine_str,
130-
name=name,
131-
input_binding_names=list(interpreter_result.input_names),
132-
output_binding_names=list(interpreter_result.output_names),
133-
settings=settings,
120+
logger.info(
121+
"Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available"
134122
)
123+
124+
return rt_cls(
125+
serialized_engine=interpreter_result.serialized_engine,
126+
input_binding_names=list(interpreter_result.input_names),
127+
output_binding_names=list(interpreter_result.output_names),
128+
name=name,
129+
settings=settings,
130+
)

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 53 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from contextlib import nullcontext
55
from typing import Any, Dict, List, Optional, Sequence, Tuple
66

7-
import tensorrt as trt
87
import torch
8+
import torch_tensorrt
99
from torch.nn import Module
1010
from torch_tensorrt._Device import Device
1111
from torch_tensorrt._enums import dtype
@@ -18,7 +18,7 @@
1818
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
1919
from torch_tensorrt.logging import TRT_LOGGER
2020

21-
import torch_tensorrt
21+
import tensorrt as trt
2222

2323
logger = logging.getLogger(__name__)
2424

@@ -30,19 +30,49 @@ class PythonTorchTensorRTModule(Module): # type: ignore[misc]
3030
FX / Dynamo / Python deployments. This module cannot be serialized to torchscript via torch.jit.trace for C++ deployment.
3131
"""
3232

33+
defer_engine_setup = False
34+
3335
def __init__(
3436
self,
35-
engine: bytes,
36-
input_names: Optional[List[str]] = None,
37-
output_names: Optional[List[str]] = None,
37+
serialized_engine: Optional[bytes] = None,
38+
input_binding_names: Optional[List[str]] = None,
39+
output_binding_names: Optional[List[str]] = None,
40+
*,
41+
name: str = "",
3842
settings: CompilationSettings = CompilationSettings(),
3943
):
44+
"""Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
45+
a PyTorch ``torch.nn.Module`` around it. Uses TensorRT Python APIs to run the engine
46+
47+
Arguments:
48+
serialized_engine (bytearray): Serialized TensorRT engine in the form of a bytearray
49+
input_binding_names (List[str]): List of input TensorRT engine binding names in the order they would be passed to the TRT modules
50+
output_binding_names (List[str]): List of output TensorRT engine binding names in the order they should be returned
51+
52+
Keyword Arguments:
53+
name (str): Name for module
54+
settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed
55+
56+
Example:
57+
58+
.. code-block:: py
59+
60+
trt_module = PythonTorchTensorRTModule(
61+
engine_str,
62+
input_binding_names=["x"],
63+
output_binding_names=["output"],
64+
name="my_module",
65+
settings=CompilationSettings(device=torch.cuda.current_device)
66+
)
67+
68+
"""
4069
super(PythonTorchTensorRTModule, self).__init__()
4170
self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict)
4271

4372
# Run multi-gpu device check to validate engine instantiation
4473
multi_gpu_device_check()
4574

75+
self.name = name
4676
self.input_buffers: List[torch.Tensor] = []
4777
self.output_buffers: List[torch.Tensor] = []
4878
self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
@@ -55,9 +85,13 @@ def __init__(
5585
# Unused currently - to be used by Dynamic Shape support implementation
5686
self.memory_pool = None
5787

58-
self.engine = engine
59-
self.input_names = input_names if input_names is not None else []
60-
self.output_names = output_names if output_names is not None else []
88+
self.serialized_engine = serialized_engine
89+
self.input_names = (
90+
input_binding_names if input_binding_names is not None else []
91+
)
92+
self.output_names = (
93+
output_binding_names if output_binding_names is not None else []
94+
)
6195
self.initialized = False
6296
self.target_device_id = (
6397
settings.device.gpu_id
@@ -69,12 +103,15 @@ def __init__(
69103
)
70104
self.profiling_enabled = settings.debug if settings.debug is not None else False
71105
self.settings = settings
72-
self._initialize()
106+
self.engine = None
73107

74-
def _initialize(self) -> None:
108+
if self.serialized_engine is not None and not self.settings.lazy_engine_init:
109+
self.setup_engine()
110+
111+
def setup_engine(self) -> None:
75112
self.initialized = True
76113
runtime = trt.Runtime(TRT_LOGGER)
77-
self.engine = runtime.deserialize_cuda_engine(self.engine)
114+
self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
78115
self.context = self.engine.create_execution_context()
79116

80117
assert self.engine.num_io_tensors == (
@@ -114,8 +151,7 @@ def _check_initialized(self) -> None:
114151
raise RuntimeError("PythonTorchTensorRTModule is not initialized.")
115152

116153
def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> None:
117-
self._check_initialized()
118-
state_dict[prefix + "engine"] = bytearray(self.engine.serialize())
154+
state_dict[prefix + "engine"] = self.serialized_engine
119155
state_dict[prefix + "input_names"] = self.input_names
120156
state_dict[prefix + "output_names"] = self.output_names
121157

@@ -129,17 +165,13 @@ def _load_from_state_dict(
129165
unexpected_keys: Any,
130166
error_msgs: Any,
131167
) -> None:
132-
engine_bytes = state_dict[prefix + "engine"]
168+
self.serialized_engine = state_dict[prefix + "engine"]
169+
self.input_names = state_dict[prefix + "input_names"]
170+
self.output_names = state_dict[prefix + "output_names"]
133171

134172
# Run multi-gpu device check to validate engine instantiation
135173
multi_gpu_device_check()
136-
137-
runtime = trt.Runtime(TRT_LOGGER)
138-
self.engine = runtime.deserialize_cuda_engine(engine_bytes)
139-
140-
self.input_names = state_dict[prefix + "input_names"]
141-
self.output_names = state_dict[prefix + "output_names"]
142-
self._initialize()
174+
self.setup_engine()
143175

144176
def __getstate__(self) -> Dict[str, Any]:
145177
state = self.__dict__.copy()

0 commit comments

Comments
 (0)