Commit 95cc532

feat: engine caching (#2995)
1 parent 3a3d62a commit 95cc532

22 files changed (+1069 -42 lines)
Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
import numpy as np
import torch
import torch_tensorrt
from engine_caching_example import remove_timing_cache
from transformers import BertModel

np.random.seed(0)
torch.manual_seed(0)

model = BertModel.from_pretrained("bert-base-uncased", return_dict=False).cuda().eval()
inputs = [
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
]


def compile_bert(iterations=3):
    times = []
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # The 1st iteration measures compilation time without engine caching.
    # The 2nd and 3rd iterations measure compilation time with engine caching.
    # Since the 2nd iteration needs to compile and save the engine, it will be
    # slower than the 1st iteration. The 3rd iteration should be faster than
    # the 1st iteration because it loads the cached engine.
    for i in range(iterations):
        # Remove the timing cache and reset dynamo for the engine caching measurement
        remove_timing_cache()
        torch._dynamo.reset()

        if i == 0:
            cache_built_engines = False
            reuse_cached_engines = False
        else:
            cache_built_engines = True
            reuse_cached_engines = True

        start.record()
        compilation_kwargs = {
            "use_python_runtime": False,
            "enabled_precisions": {torch.float},
            "truncate_double": True,
            "debug": False,
            "min_block_size": 1,
            "make_refitable": True,
            "cache_built_engines": cache_built_engines,
            "reuse_cached_engines": reuse_cached_engines,
            "engine_cache_dir": "/tmp/torch_trt_bert_engine_cache",
            "engine_cache_size": 1 << 30,  # 1GB
        }
        optimized_model = torch.compile(
            model,
            backend="torch_tensorrt",
            options=compilation_kwargs,
        )
        optimized_model(*inputs)
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    print("-----compile bert-----> compilation time:\n", times, "milliseconds")


if __name__ == "__main__":
    compile_bert()
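
Once the cached iterations have run, the engines written by the example should be visible under /tmp/torch_trt_bert_engine_cache (the engine_cache_dir set above). A minimal stdlib check, assuming only that the disk cache stores its entries as ordinary files:

import os

cache_dir = "/tmp/torch_trt_bert_engine_cache"  # engine_cache_dir from the example above
if os.path.isdir(cache_dir):
    # Walk the cache directory and report each stored file and its size
    for root, _, files in os.walk(cache_dir):
        for name in files:
            path = os.path.join(root, name)
            print(path, os.path.getsize(path), "bytes")
else:
    print("cache directory does not exist yet; run compile_bert() first")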
Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
import os
from typing import Optional

import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models
from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH
from torch_tensorrt.dynamo._engine_caching import BaseEngineCache

np.random.seed(0)
torch.manual_seed(0)

model = models.resnet18(pretrained=True).eval().to("cuda")
enabled_precisions = {torch.float}
debug = False
min_block_size = 1
use_python_runtime = False


def remove_timing_cache(path=TIMING_CACHE_PATH):
    if os.path.exists(path):
        os.remove(path)


def dynamo_compile(iterations=3):
    times = []
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
    # Mark dim 0 of the input as dynamic
    batch = torch.export.Dim("batch", min=1, max=200)
    exp_program = torch.export.export(
        model, args=example_inputs, dynamic_shapes={"x": {0: batch}}
    )

    # The 1st iteration measures compilation time without engine caching.
    # The 2nd and 3rd iterations measure compilation time with engine caching.
    # Since the 2nd iteration needs to compile and save the engine, it will be
    # slower than the 1st iteration. The 3rd iteration should be faster than
    # the 1st iteration because it loads the cached engine.
    for i in range(iterations):
        inputs = [torch.rand((100 + i, 3, 224, 224)).to("cuda")]
        remove_timing_cache()  # remove the timing cache just for the engine caching measurement
        if i == 0:
            cache_built_engines = False
            reuse_cached_engines = False
        else:
            cache_built_engines = True
            reuse_cached_engines = True

        start.record()
        trt_gm = torch_trt.dynamo.compile(
            exp_program,
            tuple(inputs),
            use_python_runtime=use_python_runtime,
            enabled_precisions=enabled_precisions,
            debug=debug,
            min_block_size=min_block_size,
            make_refitable=True,
            cache_built_engines=cache_built_engines,
            reuse_cached_engines=reuse_cached_engines,
            engine_cache_size=1 << 30,  # 1GB
        )
        # output = trt_gm(*inputs)
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    print("----------------dynamo_compile----------------")
    print("disable engine caching, used:", times[0], "ms")
    print("enable engine caching to cache engines, used:", times[1], "ms")
    print("enable engine caching to reuse engines, used:", times[2], "ms")
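
To turn the printed comparison into a hard check, a small helper could be called after the prints in dynamo_compile. This is a sketch of the expectation stated in the comments above (the reuse iteration should beat the cold compile); absolute timings vary by GPU, driver, and TensorRT version, so only the headline trend is asserted:

def check_times(times):
    # Sanity-check the trend described above: the 3rd iteration (engine
    # reuse) should be faster than the 1st (cold compile). Timings are
    # machine-dependent, so treat a failure as noise to investigate.
    assert times[2] < times[0], (
        f"expected reuse ({times[2]:.1f} ms) < cold compile ({times[0]:.1f} ms)"
    )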

# Custom Engine Cache
class MyEngineCache(BaseEngineCache):
    def __init__(
        self,
        engine_cache_dir: str,
    ) -> None:
        self.engine_cache_dir = engine_cache_dir

    def save(
        self,
        hash: str,
        blob: bytes,
        prefix: str = "blob",
    ):
        if not os.path.exists(self.engine_cache_dir):
            os.makedirs(self.engine_cache_dir, exist_ok=True)

        path = os.path.join(
            self.engine_cache_dir,
            f"{prefix}_{hash}.bin",
        )
        with open(path, "wb") as f:
            f.write(blob)

    def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]:
        path = os.path.join(self.engine_cache_dir, f"{prefix}_{hash}.bin")
        if os.path.exists(path):
            with open(path, "rb") as f:
                blob = f.read()
            return blob
        return None
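
In isolation, the custom cache above is a save/load round trip keyed by a hash string. A quick standalone exercise of that contract (the hash below is a made-up placeholder; in real use the compiler derives it from the graph and compilation settings):

cache = MyEngineCache("/tmp/your_dir")
fake_hash = "0123abcd"  # placeholder hash, not one the compiler would produce
cache.save(fake_hash, b"serialized-engine-bytes")
assert cache.load(fake_hash) == b"serialized-engine-bytes"  # round trip
assert cache.load("no-such-hash") is None  # misses return None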

def torch_compile(iterations=3):
    times = []
    engine_cache = MyEngineCache("/tmp/your_dir")
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # The 1st iteration measures compilation time without engine caching.
    # The 2nd and 3rd iterations measure compilation time with engine caching.
    # Since the 2nd iteration needs to compile and save the engine, it will be
    # slower than the 1st iteration. The 3rd iteration should be faster than
    # the 1st iteration because it loads the cached engine.
    for i in range(iterations):
        inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]
        # Remove the timing cache and reset dynamo just for the engine caching measurement
        remove_timing_cache()
        torch._dynamo.reset()

        if i == 0:
            cache_built_engines = False
            reuse_cached_engines = False
        else:
            cache_built_engines = True
            reuse_cached_engines = True

        start.record()
        compiled_model = torch.compile(
            model,
            backend="tensorrt",
            options={
                "use_python_runtime": True,
                "enabled_precisions": enabled_precisions,
                "debug": debug,
                "min_block_size": min_block_size,
                "make_refitable": True,
                "cache_built_engines": cache_built_engines,
                "reuse_cached_engines": reuse_cached_engines,
                "custom_engine_cache": engine_cache,  # use the custom engine cache
            },
        )
        compiled_model(*inputs)  # trigger the compilation
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    print("----------------torch_compile----------------")
    print("disable engine caching, used:", times[0], "ms")
    print("enable engine caching to cache engines, used:", times[1], "ms")
    print("enable engine caching to reuse engines, used:", times[2], "ms")


if __name__ == "__main__":
    dynamo_compile()
    torch_compile()

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 30 additions & 1 deletion
@@ -18,6 +18,7 @@
     dryrun_stats_display,
     parse_non_trt_nodes,
 )
+from torch_tensorrt.dynamo._engine_caching import BaseEngineCache, DiskEngineCache
 from torch_tensorrt.dynamo.conversion import (
     CompilationSettings,
     UnsupportedOperatorException,
@@ -82,6 +83,11 @@ def compile(
     hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
     timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
     lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
+    cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
+    reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
+    engine_cache_dir: Optional[str] = _defaults.ENGINE_CACHE_DIR,
+    engine_cache_size: Optional[int] = _defaults.ENGINE_CACHE_SIZE,
+    custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -147,6 +153,11 @@ def compile(
         hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
         timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
         lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
+        cache_built_engines (bool): Whether to save the compiled TRT engines to storage
+        reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
+        engine_cache_dir (Optional[str]): Directory to store the cached TRT engines
+        engine_cache_size (Optional[int]): Maximum hard-disk space (bytes) to use for the engine cache, default is 1GB. If the cache exceeds this size, the oldest engines will be removed by default
+        custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -224,6 +235,17 @@ def compile(
     gm = post_lowering(gm)
     logger.debug("Lowered Input graph: " + str(gm.graph))

+    engine_cache = None
+    if cache_built_engines or reuse_cached_engines:
+        assert (
+            make_refitable
+        ), "Engine caching requires make_refitable to be set to True"
+        engine_cache = (
+            custom_engine_cache
+            if custom_engine_cache is not None
+            else DiskEngineCache(engine_cache_dir, engine_cache_size)
+        )
+
     compilation_options = {
         "enabled_precisions": (
             enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
@@ -257,11 +279,15 @@ def compile(
         "hardware_compatible": hardware_compatible,
         "timing_cache_path": timing_cache_path,
         "lazy_engine_init": lazy_engine_init,
+        "cache_built_engines": cache_built_engines,
+        "reuse_cached_engines": reuse_cached_engines,
     }

     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
-    trt_gm = compile_module(gm, trt_arg_inputs, trt_kwarg_inputs, settings)
+    trt_gm = compile_module(
+        gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
+    )
     return trt_gm


@@ -270,6 +296,7 @@ def compile_module(
     sample_arg_inputs: Sequence[Input],
     sample_kwarg_inputs: Optional[dict[Any, Any]] = None,
     settings: CompilationSettings = CompilationSettings(),
+    engine_cache: Optional[BaseEngineCache] = None,
 ) -> torch.fx.GraphModule:
     """Compile a traced FX module
@@ -280,6 +307,7 @@ def compile_module(
         arg_inputs: Inputs to the module
         kwarg_inputs: kwargs to the module
         settings: Compilation settings
+        engine_cache: Engine cache instance to store/load compiled engines
     Returns:
         Compiled FX GraphModule
     """
@@ -436,6 +464,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
             submodule_inputs,
             settings=settings,
             name=name,
+            engine_cache=engine_cache,
         )

         trt_modules[name] = trt_module
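
Taken together, the new compile() parameters can be exercised directly from an exported program. A hedged sketch (model, shapes, and paths are illustrative; per the docstring above, passing a custom_engine_cache instance would make engine_cache_dir and engine_cache_size irrelevant):

import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

model = models.resnet18(pretrained=True).eval().to("cuda")
example_inputs = (torch.randn((8, 3, 224, 224)).to("cuda"),)
exp_program = torch.export.export(model, args=example_inputs)

trt_gm = torch_trt.dynamo.compile(
    exp_program,
    example_inputs,
    make_refitable=True,  # required: engine caching asserts this above
    cache_built_engines=True,  # save newly built engines to disk
    reuse_cached_engines=True,  # load a matching engine instead of rebuilding
    engine_cache_dir="/tmp/my_engine_cache",  # illustrative path
    engine_cache_size=1 << 30,  # 1GB cap; oldest engines evicted first
)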

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 8 additions & 1 deletion
@@ -31,8 +31,15 @@
 DRYRUN = False
 HARDWARE_COMPATIBLE = False
 SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8}
-TIMING_CACHE_PATH = os.path.join(tempfile.gettempdir(), "timing_cache.bin")
+TIMING_CACHE_PATH = os.path.join(
+    tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin"
+)
 LAZY_ENGINE_INIT = False
+CACHE_BUILT_ENGINES = True
+REUSE_CACHED_ENGINES = True
+ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
+ENGINE_CACHE_SIZE = 1073741824
+CUSTOM_ENGINE_CACHE = None


 def default_device() -> Device:
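
The new defaults co-locate the timing cache and the engine cache under one temp subdirectory and cap the disk cache at 1073741824 bytes, i.e. 1 << 30 (1GB). A sketch of building the same default cache by hand, assuming DiskEngineCache keeps the two-argument constructor that compile() uses above:

import os
import tempfile

from torch_tensorrt.dynamo._engine_caching import DiskEngineCache

engine_cache_dir = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
engine_cache_size = 1 << 30  # 1073741824 bytes, matching ENGINE_CACHE_SIZE

# Mirrors what compile() constructs when no custom_engine_cache is supplied
engine_cache = DiskEngineCache(engine_cache_dir, engine_cache_size)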
