
Commit ab5da54 ("refactor")

1 parent 0024c85 commit ab5da54

8 files changed, +227 -278 lines


examples/dynamo/engine_caching_bert_example.py

Lines changed: 7 additions & 6 deletions
@@ -29,11 +29,11 @@ def compile_bert(iterations=3):
         torch._dynamo.reset()

         if i == 0:
-            save_engine_cache = False
-            load_engine_cache = False
+            cache_built_engines = False
+            reuse_cached_engines = False
         else:
-            save_engine_cache = True
-            load_engine_cache = True
+            cache_built_engines = True
+            reuse_cached_engines = True

         start.record()
         compilation_kwargs = {
@@ -43,8 +43,9 @@ def compile_bert(iterations=3):
             "debug": False,
             "min_block_size": 1,
             "make_refitable": True,
-            "save_engine_cache": save_engine_cache,
-            "load_engine_cache": load_engine_cache,
+            "cache_built_engines": cache_built_engines,
+            "reuse_cached_engines": reuse_cached_engines,
+            "engine_cache_dir": "/tmp/torch_trt_bert_engine_cache",
             "engine_cache_size": 1 << 30,  # 1GB
         }
         optimized_model = torch.compile(
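For reference, the renamed knobs from this example end up being passed to torch.compile roughly as follows. This is a minimal sketch, not part of the diff: it assumes a BERT-style `model` and matching `inputs` already on CUDA, and that the "torch_tensorrt" Dynamo backend is registered as usual.

import torch
import torch_tensorrt  # registers the "torch_tensorrt" dynamo backend

# Sketch only: "model" and "inputs" are placeholders assumed to live on the GPU.
compilation_kwargs = {
    "debug": False,
    "min_block_size": 1,
    "make_refitable": True,
    "cache_built_engines": True,   # save newly built TRT engines into the cache
    "reuse_cached_engines": True,  # load a matching cached engine instead of rebuilding
    "engine_cache_dir": "/tmp/torch_trt_bert_engine_cache",
    "engine_cache_size": 1 << 30,  # 1GB budget for the on-disk cache
}
optimized_model = torch.compile(
    model,
    backend="torch_tensorrt",
    options=compilation_kwargs,
)
optimized_model(*inputs)  # first call triggers compilation; later runs can hit the cache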
Lines changed: 30 additions & 59 deletions
@@ -1,7 +1,5 @@
-import ast
-import logging
 import os
-from typing import List, Optional, Tuple
+from typing import Optional

 import numpy as np
 import torch
@@ -10,9 +8,6 @@
 from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH
 from torch_tensorrt.dynamo._engine_caching import BaseEngineCache

-_LOGGER: logging.Logger = logging.getLogger(__name__)
-
-
 np.random.seed(0)
 torch.manual_seed(0)
 size = (100, 3, 224, 224)
@@ -49,11 +44,11 @@ def dynamo_path(iterations=3):
         inputs = [torch.rand((100 + i, 3, 224, 224)).to("cuda")]
         remove_timing_cache()  # remove timing cache for engine caching messurement
         if i == 0:
-            save_engine_cache = False
-            load_engine_cache = False
+            cache_built_engines = False
+            reuse_cached_engines = False
         else:
-            save_engine_cache = True
-            load_engine_cache = True
+            cache_built_engines = True
+            reuse_cached_engines = True

         start.record()
         trt_gm = torch_trt.dynamo.compile(
@@ -64,8 +59,8 @@ def dynamo_path(iterations=3):
             debug=debug,
             min_block_size=min_block_size,
             make_refitable=True,
-            save_engine_cache=save_engine_cache,
-            load_engine_cache=load_engine_cache,
+            cache_built_engines=cache_built_engines,
+            reuse_cached_engines=reuse_cached_engines,
             engine_cache_size=1 << 30,  # 1GB
         )
         end.record()
@@ -79,60 +74,36 @@ def dynamo_path(iterations=3):
 class MyEngineCache(BaseEngineCache):
     def __init__(
         self,
-        engine_cache_size: int,
         engine_cache_dir: str,
     ) -> None:
-        self.total_engine_cache_size = engine_cache_size
-        self.available_engine_cache_size = engine_cache_size
         self.engine_cache_dir = engine_cache_dir

     def save(
         self,
         hash: str,
-        serialized_engine: bytes,
-        input_names: List[str],
-        output_names: List[str],
-    ) -> bool:
+        blob: bytes,
+        prefix: str = "blob",
+    ):
         path = os.path.join(
             self.engine_cache_dir,
-            f"{hash}/engine--{input_names}--{output_names}.trt",
+            f"{prefix}_{hash}.bin",
         )
-        try:
-            os.makedirs(os.path.dirname(path), exist_ok=True)
-            with open(path, "wb") as f:
-                f.write(serialized_engine)
-        except Exception as e:
-            _LOGGER.warning(f"Failed to save the TRT engine to {path}: {e}")
-            return False
-
-        _LOGGER.info(f"A TRT engine was cached to {path}")
-        serialized_engine_size = int(serialized_engine.nbytes)
-        self.available_engine_cache_size -= serialized_engine_size
-        return True
-
-    def load(self, hash: str) -> Tuple[Optional[bytes], List[str], List[str]]:
-        directory = os.path.join(self.engine_cache_dir, hash)
-        if os.path.exists(directory):
-            engine_list = os.listdir(directory)
-            assert (
-                len(engine_list) == 1
-            ), f"There are more than one engine {engine_list} under {directory}."
-            path = os.path.join(directory, engine_list[0])
-            input_names_str, output_names_str = (
-                engine_list[0].split(".trt")[0].split("--")[1:]
-            )
-            input_names = ast.literal_eval(input_names_str)
-            output_names = ast.literal_eval(output_names_str)
+        os.makedirs(path, exist_ok=True)
+        with open(path, "wb") as f:
+            f.write(blob)
+
+    def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]:
+        path = os.path.join(self.engine_cache_dir, f"{prefix}_{hash}.bin")
+        if os.path.exists(path):
             with open(path, "rb") as f:
-                serialized_engine = f.read()
-            return serialized_engine, input_names, output_names
-        else:
-            return None, [], []
+                blob = f.read()
+            return blob
+        return None


 def compile_path(iterations=3):
     times = []
-    engine_cache = MyEngineCache(200 * (1 << 20), "/tmp/your_dir")
+    engine_cache = MyEngineCache("/tmp/your_dir")
     start = torch.cuda.Event(enable_timing=True)
     end = torch.cuda.Event(enable_timing=True)

@@ -147,11 +118,11 @@ def compile_path(iterations=3):
         torch._dynamo.reset()

         if i == 0:
-            save_engine_cache = False
-            load_engine_cache = False
+            cache_built_engines = False
+            reuse_cached_engines = False
         else:
-            save_engine_cache = True
-            load_engine_cache = True
+            cache_built_engines = True
+            reuse_cached_engines = True

         start.record()
         compiled_model = torch.compile(
@@ -163,9 +134,9 @@ def compile_path(iterations=3):
                "debug": debug,
                "min_block_size": min_block_size,
                "make_refitable": True,
-               "save_engine_cache": save_engine_cache,
-               "load_engine_cache": load_engine_cache,
-               "engine_cache_instance": engine_cache,  # use custom engine cache
+               "cache_built_engines": cache_built_engines,
+               "reuse_cached_engines": reuse_cached_engines,
+               "custom_engine_cache": engine_cache,  # use custom engine cache
            },
        )
        compiled_model(*inputs)  # trigger the compilation
@@ -178,4 +149,4 @@ def compile_path(iterations=3):

 if __name__ == "__main__":
     dynamo_path()
-    compile_path()
+    # compile_path()
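With this refactor a user-provided cache only moves opaque blobs: save() receives a hash plus the packed bytes, and load() returns Optional[bytes]. As an illustration (not part of the commit), a hypothetical in-memory variant that follows the same interface the example's MyEngineCache uses could look like this; it would be handed to torch.compile through the "custom_engine_cache" option exactly like MyEngineCache above.

from typing import Dict, Optional

from torch_tensorrt.dynamo._engine_caching import BaseEngineCache


class InMemoryEngineCache(BaseEngineCache):
    """Hypothetical dict-backed cache; only useful within a single process."""

    def __init__(self) -> None:
        self.blobs: Dict[str, bytes] = {}

    def save(self, hash: str, blob: bytes, prefix: str = "blob") -> None:
        # Store the packed engine blob under the same key scheme the disk example uses.
        self.blobs[f"{prefix}_{hash}"] = blob

    def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]:
        # None signals a cache miss, mirroring the example above.
        return self.blobs.get(f"{prefix}_{hash}")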

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 13 additions & 34 deletions
@@ -18,7 +18,7 @@
     dryrun_stats_display,
     parse_non_trt_nodes,
 )
-from torch_tensorrt.dynamo._engine_caching import BaseEngineCache, EngineCache
+from torch_tensorrt.dynamo._engine_caching import BaseEngineCache, DiskEngineCache
 from torch_tensorrt.dynamo.conversion import (
     CompilationSettings,
     UnsupportedOperatorException,
@@ -84,11 +84,11 @@ def compile(
     hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
     timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
     lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
-    save_engine_cache: bool = _defaults.SAVE_ENGINE_CACHE,
-    load_engine_cache: bool = _defaults.LOAD_ENGINE_CACHE,
+    cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
+    reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
     engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR,
     engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE,
-    engine_cache_instance: Optional[BaseEngineCache] = None,
+    custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -154,11 +154,11 @@ def compile(
         hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
         timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
         lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
-        save_engine_cache (bool): Whether to save the compiled TRT engines to hard disk
-        load_engine_cache (bool): Whether to load the compiled TRT engines from hard disk
+        cache_built_engines (bool): Whether to save the compiled TRT engines to storage
+        reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
         engine_cache_dir (str): Directory to store the cached TRT engines
         engine_cache_size (int): Maximum hard-disk space to use for the engine cache
-        engine_cache_instance (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache
+        custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -235,10 +235,9 @@ def compile(
     gm = post_lowering(gm)
     logger.debug("Lowered Input graph: " + str(gm.graph))

-    if engine_cache_instance is None:
-        engine_cache_instance = EngineCacheInstanceCreator.get_creator(
-            engine_cache_size, engine_cache_dir
-        ).engine_cache_instance
+    if cache_built_engines or reuse_cached_engines:
+        if custom_engine_cache is None:
+            custom_engine_cache = DiskEngineCache(engine_cache_dir, engine_cache_size)

     compilation_options = {
         "enabled_precisions": (
@@ -273,11 +272,9 @@ def compile(
         "hardware_compatible": hardware_compatible,
         "timing_cache_path": timing_cache_path,
         "lazy_engine_init": lazy_engine_init,
-        "save_engine_cache": save_engine_cache,
-        "load_engine_cache": load_engine_cache,
-        "engine_cache_dir": engine_cache_dir,
-        "engine_cache_size": engine_cache_size,
-        "engine_cache_instance": engine_cache_instance,
+        "cache_built_engines": cache_built_engines,
+        "reuse_cached_engines": reuse_cached_engines,
+        "custom_engine_cache": custom_engine_cache,
     }

     settings = CompilationSettings(**compilation_options)
@@ -724,21 +721,3 @@ def convert_exported_program_to_serialized_trt_engine(

     serialized_engine: bytes = interpreter_result.serialized_engine
     return serialized_engine
-
-
-class EngineCacheInstanceCreator:
-    engine_cache_creator = None
-
-    def __init__(self, engine_cache_size: int, engine_cache_dir: str) -> None:
-        self.engine_cache_instance = EngineCache(
-            engine_cache_size=engine_cache_size,
-            engine_cache_dir=engine_cache_dir,
-        )
-
-    @classmethod
-    def get_creator(
-        cls, engine_cache_size: int, engine_cache_dir: str
-    ) -> EngineCacheInstanceCreator:
-        if cls.engine_cache_creator is None:
-            cls.engine_cache_creator = cls(engine_cache_size, engine_cache_dir)
-        return cls.engine_cache_creator
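Given the updated signature, opting into engine caching through the ahead-of-time path needs only the two booleans; per the code above, the default DiskEngineCache is created on demand unless custom_engine_cache is supplied. A minimal sketch under those assumptions (model and inputs are placeholders, not from the diff):

import torch
import torch_tensorrt

# Placeholders: any exportable module and its example inputs on CUDA.
exp_program = torch.export.export(model, tuple(inputs))
trt_gm = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs=inputs,
    make_refitable=True,            # as in the examples above
    cache_built_engines=True,       # write newly built engines to the cache
    reuse_cached_engines=True,      # pull matching engines from the cache
    engine_cache_dir="/tmp/torch_trt_engine_cache",  # ignored when custom_engine_cache is set
    engine_cache_size=1 << 30,      # budget for the default DiskEngineCache
)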

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 3 additions & 6 deletions
@@ -4,7 +4,6 @@
 import torch
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import EngineCapability, dtype
-from torch_tensorrt.dynamo._engine_caching import EngineCache

 ENABLED_PRECISIONS = {dtype.f32}
 DEBUG = False
@@ -36,13 +35,11 @@
     tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin"
 )
 LAZY_ENGINE_INIT = False
-SAVE_ENGINE_CACHE = True
-LOAD_ENGINE_CACHE = True
+CACHE_BUILT_ENGINES = True
+REUSE_CACHED_ENGINES = True
 ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
 ENGINE_CACHE_SIZE = 1073741824
-ENGINE_CACHE_INSTANCE = EngineCache(
-    engine_cache_size=ENGINE_CACHE_SIZE, engine_cache_dir=ENGINE_CACHE_DIR
-)
+CUSTOM_ENGINE_CACHE = None


 def default_device() -> Device:
