Commit 8759736

docs: Adding words to the refit and engine caching tutorials (#3141)
Signed-off-by: Naren Dasan <[email protected]>
1 parent 8e75039 commit 8759736

15 files changed, +273 -81 lines changed

core/runtime/Platform.cpp

Lines changed: 0 additions & 1 deletion
@@ -36,7 +36,6 @@ Platform::Platform() : _platform{Platform::PlatformEnum::kUNKNOWN} {}
 Platform::Platform(Platform::PlatformEnum val) : _platform{val} {}
 
 Platform::Platform(const std::string& platform_str) {
-  LOG_ERROR("Platform constructor: " << platform_str);
   auto name_map = get_name_to_platform_map();
   auto it = name_map.find(platform_str);
   if (it != name_map.end()) {

docsrc/conf.py

Lines changed: 1 addition & 0 deletions
@@ -93,6 +93,7 @@
 sphinx_gallery_conf = {
     "examples_dirs": "../examples",
     "gallery_dirs": "tutorials/_rendered_examples/",
+    "ignore_pattern": "utils.py",
 }
 
 # Setup the breathe extension

docsrc/index.rst

Lines changed: 2 additions & 0 deletions
@@ -51,6 +51,8 @@ User Guide
     user_guide/using_dla
     tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage
     tutorials/_rendered_examples/dynamo/vgg16_fp8_ptq
+    tutorials/_rendered_examples/dynamo/engine_caching_example
+    tutorials/_rendered_examples/dynamo/refit_engine_example
 
 Dynamo Frontend
 ----------------

examples/dynamo/README.rst

Lines changed: 2 additions & 0 deletions
@@ -15,3 +15,5 @@ a number of ways you can leverage this backend to accelerate inference.
 * :ref:`refit_engine_example`: Refitting a compiled TensorRT Graph Module with updated weights
 * :ref:`mutable_torchtrt_module_example`: Compile, use, and modify TensorRT Graph Module with MutableTorchTensorRTModule
 * :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 and PTQ using ``torch.compile``
+* :ref:`engine_caching_example`: Utilizing engine caching to speed up compilation times
+* :ref:`engine_caching_bert_example`: Demonstrating engine caching on BERT

examples/dynamo/engine_caching_bert_example.py

Lines changed: 10 additions & 0 deletions
@@ -1,3 +1,13 @@
+"""
+
+.. _engine_caching_bert_example:
+
+Engine Caching (BERT)
+=======================
+
+Small caching example on BERT.
+"""
+
 import numpy as np
 import torch
 import torch_tensorrt
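
The body of this example is not included in the diff. As a rough sketch of the pattern the example describes, the snippet below applies the same caching options shown elsewhere in this commit to a BERT model; it assumes the Hugging Face ``transformers`` package and is an illustration, not the file's actual contents.

import torch
import torch_tensorrt  # noqa: F401  # registers the "tensorrt" backend for torch.compile
from transformers import BertModel

# Pre-trained BERT moved to the GPU for TensorRT compilation
model = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda")
inputs = [
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),  # input_ids
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),  # attention_mask
]

compiled_model = torch.compile(
    model,
    backend="tensorrt",
    options={
        "use_python_runtime": True,
        "make_refitable": True,       # cached engines must be refittable
        "cache_built_engines": True,  # save newly built engines to the cache
        "reuse_cached_engines": True, # refit and reuse engines found in the cache
    },
)
compiled_model(*inputs)  # first call triggers compilation (and populates the cache)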

examples/dynamo/engine_caching_example.py

Lines changed: 155 additions & 27 deletions
@@ -1,12 +1,38 @@
+"""
+
+.. _engine_caching_example:
+
+Engine Caching
+=======================
+
+As model sizes increase, the cost of compilation will as well. With AOT methods
+like ``torch.dynamo.compile``, this cost is paid upfront. However, if the weights
+change, the session ends, or you are using JIT methods like ``torch.compile``,
+graphs get invalidated and re-compiled, and this cost is paid repeatedly.
+Engine caching is a way to mitigate this cost by saving constructed engines to disk
+and re-using them when possible. This tutorial demonstrates how to use engine caching
+with TensorRT in PyTorch. Engine caching can significantly speed up subsequent model
+compilations by reusing previously built TensorRT engines.
+
+We'll explore two approaches:
+
+1. Using torch_tensorrt.dynamo.compile
+2. Using torch.compile with the TensorRT backend
+
+The example uses a pre-trained ResNet18 model and shows the
+differences between compilation without caching, with caching enabled,
+and when reusing cached engines.
+"""
+
 import os
-from typing import Optional
+from typing import Dict, Optional
 
 import numpy as np
 import torch
 import torch_tensorrt as torch_trt
 import torchvision.models as models
 from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH
-from torch_tensorrt.dynamo._engine_caching import BaseEngineCache
+from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
 
 np.random.seed(0)
 torch.manual_seed(0)
@@ -23,6 +49,80 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
         os.remove(path)
 
 
+# %%
+# Engine Caching for JIT Compilation
+# ----------------------------------
+#
+# The primary goal of engine caching is to help speed up JIT workflows. ``torch.compile``
+# provides a great deal of flexibility in model construction, which makes it a good
+# first tool to try when looking to speed up your workflow. However, historically
+# the cost of compilation, and in particular recompilation, has been a barrier to entry
+# for many users. Prior to the addition of engine caching, if a subgraph got invalidated,
+# that graph was rebuilt from scratch. Now, with ``cache_built_engines=True``, engines are
+# saved to disk as they are constructed, tied to a hash of their corresponding PyTorch subgraph. If
+# that subgraph is seen again in a subsequent compilation, either in the same session or a new one,
+# the cache will pull the built engine and **refit** the weights, which can reduce compilation times by orders of magnitude.
+# As such, in order to insert a new engine into the cache (i.e. ``cache_built_engines=True``),
+# the engine must be refittable (``make_refitable=True``). See :ref:`refit_engine_example` for more details.
+
+
+def torch_compile(iterations=3):
+    times = []
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+
+    # The 1st iteration is to measure the compilation time without engine caching
+    # The 2nd and 3rd iterations are to measure the compilation time with engine caching.
+    # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
+    # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
+    for i in range(iterations):
+        inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]
+        # remove timing cache and reset dynamo just for engine caching measurement
+        remove_timing_cache()
+        torch._dynamo.reset()
+
+        if i == 0:
+            cache_built_engines = False
+            reuse_cached_engines = False
+        else:
+            cache_built_engines = True
+            reuse_cached_engines = True
+
+        start.record()
+        compiled_model = torch.compile(
+            model,
+            backend="tensorrt",
+            options={
+                "use_python_runtime": True,
+                "enabled_precisions": enabled_precisions,
+                "debug": debug,
+                "min_block_size": min_block_size,
+                "make_refitable": True,
+                "cache_built_engines": cache_built_engines,
+                "reuse_cached_engines": reuse_cached_engines,
+            },
+        )
+        compiled_model(*inputs)  # trigger the compilation
+        end.record()
+        torch.cuda.synchronize()
+        times.append(start.elapsed_time(end))
+
+    print("----------------torch_compile----------------")
+    print("disable engine caching, used:", times[0], "ms")
+    print("enable engine caching to cache engines, used:", times[1], "ms")
+    print("enable engine caching to reuse engines, used:", times[2], "ms")
+
+
+torch_compile()
+
+# %%
+# Engine Caching for AOT Compilation
+# ----------------------------------
+# Similarly to the JIT workflow, AOT workflows can benefit from engine caching.
+# As the same architecture or common subgraphs get recompiled, the cache will pull
+# previously built engines and refit the weights.
+
+
 def dynamo_compile(iterations=3):
     times = []
     start = torch.cuda.Event(enable_timing=True)
@@ -73,42 +173,72 @@ def dynamo_compile(iterations=3):
     print("enable engine caching to reuse engines, used:", times[2], "ms")
 
 
+dynamo_compile()
+
+# %%
 # Custom Engine Cache
-class MyEngineCache(BaseEngineCache):
+# ----------------------
+#
+# By default, the engine cache is stored in the system's temporary directory. Both the cache directory and
+# size limit can be customized by passing ``engine_cache_dir`` and ``engine_cache_size``.
+# Users can also define their own engine cache implementation by extending the ``BaseEngineCache`` class.
+# This allows for remote or shared caching if so desired.
+#
+# The custom engine cache should implement the following methods:
+#   - ``save``: Save the engine blob to the cache.
+#   - ``load``: Load the engine blob from the cache.
+#
+# The hash provided by the cache system is a weight-agnostic hash of the originating PyTorch subgraph (post lowering).
+# The blob contains a serialized engine, calling spec data, and weight map information in the pickle format.
+#
+# Below is an example of a custom engine cache implementation that implements a ``RAMEngineCache``.
+
+
+class RAMEngineCache(BaseEngineCache):
     def __init__(
         self,
-        engine_cache_dir: str,
     ) -> None:
-        self.engine_cache_dir = engine_cache_dir
+        """
+        Constructs a user-held engine cache in memory.
+        """
+        self.engine_cache: Dict[str, bytes] = {}
 
     def save(
         self,
         hash: str,
         blob: bytes,
-        prefix: str = "blob",
     ):
-        if not os.path.exists(self.engine_cache_dir):
-            os.makedirs(self.engine_cache_dir, exist_ok=True)
+        """
+        Insert the engine blob into the cache.
 
-        path = os.path.join(
-            self.engine_cache_dir,
-            f"{prefix}_{hash}.bin",
-        )
-        with open(path, "wb") as f:
-            f.write(blob)
+        Args:
+            hash (str): The hash key to associate with the engine blob.
+            blob (bytes): The engine blob to be saved.
 
-    def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]:
-        path = os.path.join(self.engine_cache_dir, f"{prefix}_{hash}.bin")
-        if os.path.exists(path):
-            with open(path, "rb") as f:
-                blob = f.read()
-            return blob
-        return None
+        Returns:
+            None
+        """
+        self.engine_cache[hash] = blob
 
+    def load(self, hash: str) -> Optional[bytes]:
+        """
+        Load the engine blob from the cache.
 
-def torch_compile(iterations=3):
+        Args:
+            hash (str): The hash key of the engine to load.
+
+        Returns:
+            Optional[bytes]: The engine blob if found, None otherwise.
+        """
+        if hash in self.engine_cache:
+            return self.engine_cache[hash]
+        else:
+            return None
+
+
+def torch_compile_my_cache(iterations=3):
     times = []
-    engine_cache = MyEngineCache("/tmp/your_dir")
+    engine_cache = RAMEngineCache()
     start = torch.cuda.Event(enable_timing=True)
     end = torch.cuda.Event(enable_timing=True)
 
@@ -141,7 +271,7 @@ def torch_compile(iterations=3):
                 "make_refitable": True,
                 "cache_built_engines": cache_built_engines,
                 "reuse_cached_engines": reuse_cached_engines,
-                "custom_engine_cache": engine_cache,  # use custom engine cache
+                "custom_engine_cache": engine_cache,
             },
         )
         compiled_model(*inputs)  # trigger the compilation
@@ -155,6 +285,4 @@ def torch_compile(iterations=3):
     print("enable engine caching to reuse engines, used:", times[2], "ms")
 
 
-if __name__ == "__main__":
-    dynamo_compile()
-    torch_compile()
+torch_compile_my_cache()
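
The body of ``dynamo_compile`` is largely elided from this diff. A minimal sketch of what the AOT path with caching enabled roughly looks like is shown below; the option names follow the prose above and the ``torch.compile`` examples, and the exact call is an approximation rather than the file's verbatim contents.

import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

model = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]
exp_program = torch.export.export(model, tuple(inputs))

trt_gm = torch_trt.dynamo.compile(
    exp_program,
    tuple(inputs),
    use_python_runtime=True,
    enabled_precisions={torch.float},
    make_refitable=True,       # required so cached engines can be refit
    cache_built_engines=True,  # write newly built engines to the cache
    reuse_cached_engines=True, # refit and reuse engines found in the cache
    engine_cache_dir="/tmp/torch_trt_engine_cache",  # optional: override the default cache location
)
trt_gm(*inputs)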

examples/dynamo/refit_engine_example.py

Lines changed: 63 additions & 21 deletions
@@ -1,19 +1,26 @@
 """
 .. _refit_engine_example:
 
-Refit TenorRT Graph Module with Torch-TensorRT
+Refitting Torch-TensorRT Programs with New Weights
 ===================================================================
 
-We are going to demonstrate how a compiled TensorRT Graph Module can be refitted with updated weights.
-
-In many cases, we frequently update the weights of models, such as applying various LoRA to Stable Diffusion or constant A/B testing of AI products.
-That poses challenges for TensorRT inference optimizations, as compiling the TensorRT engines takes significant time, making repetitive compilation highly inefficient.
-Torch-TensorRT supports refitting TensorRT graph modules without re-compiling the engine, considerably accelerating the workflow.
+Compilation is an expensive operation as it involves many graph transformations, translations
+and optimizations applied on the model. In cases where the weights of a model might be updated
+occasionally (e.g. inserting LoRA adapters), the large cost of recompilation can make it infeasible
+to use TensorRT if the compiled program needed to be built from scratch each time. Torch-TensorRT
+provides a PyTorch native mechanism to update the weights of a compiled TensorRT program without
+recompiling from scratch through weight refitting.
 
 In this tutorial, we are going to walk through
-1. Compiling a PyTorch model to a TensorRT Graph Module
-2. Save and load a graph module
-3. Refit the graph module
+
+1. Compiling a PyTorch model to a TensorRT Graph Module
+2. Saving and loading a graph module
+3. Refitting the graph module
+
+This tutorial focuses mostly on the AOT workflow, where it is most likely that a user might need to
+manually refit a module. In the JIT workflow, weight changes trigger recompilation. As the engine
+has previously been built, with an engine cache enabled, Torch-TensorRT can automatically recognize
+a previously built engine, trigger refit, and short-circuit recompilation on behalf of the user (see: :ref:`engine_caching_example`).
 """
 
 # %%
@@ -36,10 +43,17 @@
 
 
 # %%
-# Compile the module for the first time and save it.
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-model = models.resnet18(pretrained=True).eval().to("cuda")
+# Make a Refittable Compilation Program
+# ---------------------------------------
+#
+# The initial step is to compile a module and save it as normal. Note that there is an
+# additional parameter `make_refitable` that is set to `True`. This parameter is used to
+# indicate that the engine being built should support weight refitting later. Engines built without
+# this setting will not be able to be refit.
+#
+# In this case we are going to compile a ResNet18 model with randomly initialized weights and save it.
+
+model = models.resnet18(pretrained=False).eval().to("cuda")
 exp_program = torch.export.export(model, tuple(inputs))
 enabled_precisions = {torch.float}
 debug = False
@@ -59,16 +73,20 @@
 ) # Output is a torch.fx.GraphModule
 
 # Save the graph module as an exported program
-# This is only supported when use_python_runtime = False
 torch_trt.save(trt_gm, "./compiled.ep", inputs=inputs)
 
 
 # %%
-# Refit the module with update model weights
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# Refit the Program with Pretrained Weights
+# ------------------------------------------
+#
+# Random weights are not useful for inference. But now instead of recompiling the model, we can
+# refit the model with the pretrained weights. This is done by setting up another PyTorch module
+# with the target weights and exporting it as an ExportedProgram. Then the ``refit_module_weights``
+# function is used to update the weights of the compiled module with the new weights.
 
 # Create and compile the updated model
-model2 = models.resnet18(pretrained=False).eval().to("cuda")
+model2 = models.resnet18(pretrained=True).eval().to("cuda")
 exp_program2 = torch.export.export(model2, tuple(inputs))
 
 
@@ -91,8 +109,32 @@
 print("Refit successfully!")
 
 # %%
-# Alternative Workflow using Python Runtime
+#
+# Advanced Usage
 # -----------------------------
-
-# Currently python runtime does not support engine serialization. So the refitting will be done in the same runtime.
-# This usecase is more useful when you need to switch different weights in the same runtime, such as using Stable Diffusion.
+#
+# There are a number of settings you can use to control the refit process.
+#
+# Weight Map Cache
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# Weight refitting works by matching the weights of the compiled module with the new weights from
+# the user-supplied ExportedProgram. Since 1:1 name matching from PyTorch to TensorRT is hard to accomplish,
+# the only guaranteed way to match weights at *refit-time* is to pass the new ExportedProgram through the
+# early phases of the compilation process to generate nearly identical weight names. This can be expensive
+# and is not always necessary.
+#
+# To avoid this, **at initial compile time**, Torch-TensorRT will attempt to cache a direct mapping from PyTorch
+# weights to TensorRT weights. This cache is stored in the compiled module as metadata and can be used
+# to speed up refit. If the cache is not present, the refit system will fall back to rebuilding the mapping at
+# refit-time. Use of this cache is controlled by the ``use_weight_map_cache`` parameter.
+#
+# Since the cache uses a heuristic-based system for matching PyTorch and TensorRT weights, you may want to verify the refitting. This can be done by setting
+# ``verify_output`` to True and providing sample ``arg_inputs`` and ``kwarg_inputs``. When this is done, the refit
+# system will run the refitted module and the user-supplied module on the same inputs and compare the outputs.
+#
+# In-Place Refit
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# ``in_place`` allows the user to refit the module in place. This is useful when the user wants to update the weights
+# of the compiled module without creating a new module.
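
The refit call itself is not shown in the hunks above. Below is a rough, self-contained sketch of how the settings described in this section fit together; it assumes ``refit_module_weights`` is importable from ``torch_tensorrt.dynamo`` and accepts the parameters named in the prose (``arg_inputs``, ``verify_output``, ``use_weight_map_cache``, ``in_place``), and it is an approximation rather than the tutorial's verbatim code.

import torch
import torch_tensorrt as torch_trt
import torchvision.models as models
from torch_tensorrt.dynamo import refit_module_weights

inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]

# Refittable compilation program (random weights), as in the tutorial above.
model = models.resnet18(pretrained=False).eval().to("cuda")
exp_program = torch.export.export(model, tuple(inputs))
trt_gm = torch_trt.dynamo.compile(
    exp_program,
    tuple(inputs),
    make_refitable=True,  # required so the program can be refit later
)

# ExportedProgram carrying the target (pretrained) weights.
model2 = models.resnet18(pretrained=True).eval().to("cuda")
exp_program2 = torch.export.export(model2, tuple(inputs))

# Refit with the advanced settings described above.
new_trt_gm = refit_module_weights(
    compiled_module=trt_gm,
    new_weight_module=exp_program2,
    arg_inputs=inputs,          # sample inputs, used when verify_output=True
    verify_output=True,         # run both modules on arg_inputs and compare outputs
    use_weight_map_cache=True,  # reuse the weight map cached at compile time when available
    in_place=False,             # set True to update trt_gm itself instead of returning a new module
)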
