Compilation is an expensive operation as it involves many graph transformations, translations,
and optimizations applied to the model. In cases where the weights of a model might be updated
occasionally (e.g. inserting LoRA adapters), the large cost of recompilation can make it infeasible
to use TensorRT if the compiled program needed to be rebuilt from scratch each time. Torch-TensorRT
provides a PyTorch-native mechanism, weight refitting, to update the weights of a compiled TensorRT
program without recompiling from scratch.
In this tutorial, we are going to walk through:

1. Compiling a PyTorch model to a TensorRT Graph Module
2. Saving and loading a graph module
3. Refitting the graph module

This tutorial focuses mostly on the AOT workflow, where it is most likely that a user might need to
manually refit a module. In the JIT workflow, weight changes trigger recompilation; since the engine
has previously been built, with an engine cache enabled Torch-TensorRT can automatically recognize
a previously built engine, trigger refit, and short-circuit recompilation on behalf of the user
(see: :ref:`engine_caching_example`).
"""
# %%
#
# There are a number of settings you can use to control the refit process.
#
# Weight Map Cache
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Weight refitting works by matching the weights of the compiled module with the new weights from
# the user-supplied ExportedProgram. Since 1:1 name matching from PyTorch to TensorRT is hard to
# accomplish, the only guaranteed way to match weights at *refit-time* is to pass the new
# ExportedProgram through the early phases of the compilation process to generate near-identical
# weight names. This can be expensive and is not always necessary.
#
# To avoid this, **at initial compile**, Torch-TensorRT will attempt to cache a direct mapping from
# PyTorch weights to TensorRT weights. This cache is stored in the compiled module as metadata and
# can be used to speed up refit. If the cache is not present, the refit system will fall back to
# rebuilding the mapping at refit-time. Use of this cache is controlled by the
# ``use_weight_map_cache`` parameter.
#
# Since the cache uses a heuristic-based system for matching PyTorch and TensorRT weights, you may
# want to verify the refit. This can be done by setting ``verify_output`` to True and providing
# sample ``arg_inputs`` and ``kwarg_inputs``. When this is done, the refit system will run the
# refitted module and the user-supplied module on the same inputs and compare the outputs.
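# %%
# A hedged sketch of a refit call using these settings. The tiny model, the CUDA
# guard, and the result flag are illustrative; ``refit_module_weights`` with
# ``use_weight_map_cache``/``verify_output``/``arg_inputs`` comes from the
# ``torch_tensorrt.dynamo`` API, while the compile flag that makes the engine
# refittable varies by release (``immutable_weights=False`` in recent releases,
# ``make_refitable=True`` in older ones). The block only runs where CUDA and
# TensorRT are available:

import torch

verified_refit = None  # stays None when no GPU/TensorRT is present
if torch.cuda.is_available():
    import torch_tensorrt as torch_trt
    from torch_tensorrt.dynamo import refit_module_weights

    net = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).cuda().eval()
    net_updated = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).cuda().eval()
    args = (torch.rand(2, 8, device="cuda"),)

    trt_gm = torch_trt.dynamo.compile(
        torch.export.export(net, args),
        arg_inputs=args,
        immutable_weights=False,  # engine must be built refittable (assumed flag name)
    )
    trt_gm = refit_module_weights(
        trt_gm,
        torch.export.export(net_updated, args),
        arg_inputs=args,
        use_weight_map_cache=True,  # reuse the mapping cached at initial compile
        verify_output=True,         # run both modules on the same inputs and compare
    )
    verified_refit = True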
# %%
# In-Place Refit
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# ``in_place`` allows the user to refit the module in place. This is useful when the user wants to update the weights
# of the compiled module without creating a new module.
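# %%
# A compact sketch of in-place refit under the same assumptions as above
# (illustrative model, CUDA guard, release-dependent flag for building a
# refittable engine). With ``in_place=True`` the compiled module is mutated
# directly, so the return value can be ignored:

import torch

inplace_refit_done = None  # stays None when no GPU/TensorRT is present
if torch.cuda.is_available():
    import torch_tensorrt as torch_trt
    from torch_tensorrt.dynamo import refit_module_weights

    base = torch.nn.Linear(4, 4).cuda().eval()
    updated = torch.nn.Linear(4, 4).cuda().eval()
    args = (torch.rand(1, 4, device="cuda"),)

    trt_gm = torch_trt.dynamo.compile(
        torch.export.export(base, args),
        arg_inputs=args,
        immutable_weights=False,  # assumed flag name; older releases used make_refitable=True
    )
    # After this call, the same trt_gm object carries the updated weights.
    refit_module_weights(
        trt_gm, torch.export.export(updated, args), arg_inputs=args, in_place=True
    )
    inplace_refit_done = True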