Skip to content

Commit dc50ee0

Browse files
committed
Added kwarg to refit
1 parent 604d58c commit dc50ee0

File tree

7 files changed

+132
-28
lines changed

7 files changed

+132
-28
lines changed
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
"""
2+
.. _refit_engine_example:
3+
4+
Refit TenorRT Graph Module with Torch-TensorRT
5+
===================================================================
6+
7+
We are going to demonstrate how a compiled TensorRT Graph Module can be refitted with updated weights.
8+
9+
In many cases, we frequently update the weights of models, such as applying various LoRA to Stable Diffusion or constant A/B testing of AI products.
10+
That poses challenges for TensorRT inference optimizations, as compiling the TensorRT engines takes significant time, making repetitive compilation highly inefficient.
11+
Torch-TensorRT supports refitting TensorRT graph modules without re-compiling the engine, considerably accelerating the workflow.
12+
13+
In this tutorial, we are going to walk through
14+
1. Compiling a PyTorch model to a TensorRT Graph Module
15+
2. Save and load a graph module
16+
3. Refit the graph module
17+
"""
18+
19+
# %%
20+
# Standard Workflow
21+
# -----------------------------
22+
23+
# %%
24+
# Imports and model definition
25+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
26+
27+
import numpy as np
28+
import torch
29+
import torch_tensorrt as torch_trt
30+
import torchvision.models as models
31+
from torch_tensorrt.dynamo import refit_module_weights
32+
33+
np.random.seed(0)
34+
torch.manual_seed(0)
35+
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]
36+
37+
38+
# %%
39+
# Compile the module for the first time and save it.
40+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
41+
42+
model = models.resnet18(pretrained=False).eval().to("cuda")
43+
exp_program = torch.export.export(model, tuple(inputs))
44+
enabled_precisions = {torch.float}
45+
debug = False
46+
workspace_size = 20 << 30
47+
min_block_size = 0
48+
use_python_runtime = False
49+
torch_executed_ops = {}
50+
trt_gm = torch_trt.dynamo.compile(
51+
exp_program,
52+
tuple(inputs),
53+
use_python_runtime=use_python_runtime,
54+
enabled_precisions=enabled_precisions,
55+
debug=debug,
56+
min_block_size=min_block_size,
57+
torch_executed_ops=torch_executed_ops,
58+
make_refitable=True,
59+
) # Output is a torch.fx.GraphModule
60+
61+
# Save the graph module as an exported program
62+
# This is only supported when use_python_runtime = False
63+
torch_trt.save(trt_gm, "./compiled.ep", inputs=inputs)
64+
65+
66+
# %%
67+
# Refit the module with update model weights
68+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
69+
70+
# Create and compile the updated model
71+
model2 = models.resnet18(pretrained=True).eval().to("cuda")
72+
exp_program2 = torch.export.export(model2, tuple(inputs))
73+
74+
75+
compiled_trt_ep = torch_trt.load("./compiled.ep")
76+
77+
# This returns a new module with updated weights
78+
new_trt_gm = refit_module_weights(
79+
compiled_module=compiled_trt_ep,
80+
new_weight_module=exp_program2,
81+
arg_inputs=inputs,
82+
)
83+
84+
# Check the output
85+
expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(*inputs)
86+
for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
87+
assert torch.allclose(
88+
expected_output, refitted_output, 1e-2, 1e-2
89+
), "Refit Result is not correct. Refit failed"
90+
91+
print("Refit successfully!")
92+
93+
# %%
94+
# Alternative Workflow using Python Runtime
95+
# -----------------------------
96+
97+
# Currently python runtime does not support engine serialization. So the refitting will be done in the same runtime.
98+
# This usecase is more useful when you need to switch different weights in the same runtime, such as using Stable Diffusion.

examples/dynamo/refit_engine_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@
7878
new_trt_gm = refit_module_weights(
7979
compiled_module=compiled_trt_ep,
8080
new_weight_module=exp_program2,
81-
inputs=inputs,
81+
arg_inputs=inputs,
8282
)
8383

8484
# Check the output

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import collections.abc
44
import copy
55
import logging
6-
from typing import Any, Sequence, Tuple
6+
from typing import Any, Optional, Sequence, Tuple
77

88
import numpy as np
99
import tensorrt as trt
@@ -36,7 +36,6 @@
3636
from torch_tensorrt.dynamo.utils import (
3737
check_output,
3838
get_torch_inputs,
39-
prepare_inputs,
4039
set_log_level,
4140
to_torch_device,
4241
to_torch_tensorrt_device,
@@ -146,7 +145,8 @@ def _refit_single_trt_engine_with_gm(
146145
def refit_module_weights(
147146
compiled_module: torch.fx.GraphModule | ExportedProgram,
148147
new_weight_module: ExportedProgram,
149-
inputs: Tuple[Any, ...],
148+
arg_inputs: Optional[Tuple[Any, ...]] = None,
149+
kwarg_inputs: Optional[dict[str, Any]] = None,
150150
verify_output: bool = False,
151151
) -> torch.fx.GraphModule:
152152
"""
@@ -208,27 +208,29 @@ def refit_module_weights(
208208
if settings.debug:
209209
set_log_level(logger.parent, logging.DEBUG)
210210

211-
if not isinstance(inputs, collections.abc.Sequence):
212-
inputs = [inputs]
213-
214-
# Prepare torch_trt inputs
215-
inputs = prepare_inputs(inputs)
216211
device = to_torch_tensorrt_device(settings.device)
217-
torch_inputs = get_torch_inputs(inputs, device)
212+
if arg_inputs:
213+
if not isinstance(arg_inputs, collections.abc.Sequence):
214+
# Prepare torch_trt inputs
215+
arg_inputs = [arg_inputs]
216+
torch_inputs = get_torch_inputs(arg_inputs, device)
217+
218+
if kwarg_inputs:
219+
torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
218220
runtime = trt.Runtime(TRT_LOGGER)
219221
if not isinstance(new_weight_module, ExportedProgram):
220222
raise AssertionError(
221223
f"Input graph should be an ExportedProgram but got type {type(new_weight_module)}"
222224
)
223-
new_weight_module = pre_export_lowering(new_weight_module, torch_inputs)
225+
new_weight_module = pre_export_lowering(new_weight_module)
224226
new_weight_module = new_weight_module.run_decompositions(
225227
get_decompositions(settings.enable_experimental_decompositions)
226228
)
227229
new_gm = new_weight_module.module()
228230
logger.debug("Input graph: " + str(new_gm.graph))
229231
# Apply lowering on the graph module
230232

231-
new_gm = post_lowering(new_gm, torch_inputs)
233+
new_gm = post_lowering(new_gm)
232234

233235
logger.info("Compilation Settings: %s\n", settings)
234236

@@ -354,11 +356,12 @@ def refit_module_weights(
354356
refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info))
355357
setattr(compiled_module, f"{name}_engine", refitted_engine)
356358

357-
if verify_output:
359+
if verify_output and arg_inputs is not None:
358360
if check_output(
359361
new_module=new_gm,
360362
refitted_module=compiled_module,
361-
inputs=torch_inputs,
363+
arg_inputs=torch_inputs,
364+
kwarg_inputs=torch_kwarg_inputs,
362365
):
363366
logger.info("Refitting Succeed!")
364367
else:

py/torch_tensorrt/dynamo/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,9 +397,12 @@ def function_wrapper(*args: Any, **kwargs: Any) -> Any:
397397
def check_output(
398398
new_module: torch.fx.GraphModule,
399399
refitted_module: torch.fx.GraphModule,
400-
inputs: tuple[Any, ...],
400+
arg_inputs: Any,
401+
kwarg_inputs: Any = None,
401402
) -> bool:
402-
old_outputs, new_outputs = refitted_module(*inputs), new_module(*inputs)
403+
old_outputs, new_outputs = refitted_module(*arg_inputs), new_module(
404+
*arg_inputs, **kwarg_inputs
405+
)
403406
for old_output, new_output in zip(old_outputs, new_outputs):
404407
if isinstance(old_output, torch.Tensor) and isinstance(
405408
new_outputs, torch.Tensor

tests/py/dynamo/models/test_dtype_support.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def forward(self, x):
3636
exp_mod = torch.export.export(mod, (in_tensor,))
3737
trt_mod = torch_tensorrt.dynamo.compile(
3838
exp_mod,
39-
arg_inputs=[in_tensor],
39+
inputs=[in_tensor],
4040
pass_through_build_failures=True,
4141
truncate_double=True,
4242
min_block_size=1,
@@ -74,7 +74,7 @@ def forward(self, x):
7474
exp_mod = torch.export.export(mod, (in_tensor,))
7575
trt_mod = torch_tensorrt.dynamo.compile(
7676
exp_mod,
77-
arg_inputs=[in_tensor],
77+
inputs=[in_tensor],
7878
pass_through_build_failures=True,
7979
truncate_double=True,
8080
min_block_size=1,
@@ -118,7 +118,7 @@ def forward(self, x):
118118
exp_mod = torch.export.export(mod, (in_tensor,))
119119
trt_mod = torch_tensorrt.dynamo.compile(
120120
exp_mod,
121-
arg_inputs=[in_tensor],
121+
inputs=[in_tensor],
122122
pass_through_build_failures=True,
123123
truncate_double=False,
124124
min_block_size=1,
@@ -157,7 +157,7 @@ def forward(self, x):
157157
exp_mod = torch.export.export(mod, (in_tensor,))
158158
trt_mod = torch_tensorrt.dynamo.compile(
159159
exp_mod,
160-
arg_inputs=[in_tensor],
160+
inputs=[in_tensor],
161161
pass_through_build_failures=True,
162162
truncate_double=False,
163163
min_block_size=1,
@@ -201,7 +201,7 @@ def forward(self, x):
201201
exp_mod = torch.export.export(mod, (in_tensor,))
202202
trt_mod = torch_tensorrt.dynamo.compile(
203203
exp_mod,
204-
arg_inputs=[in_tensor],
204+
inputs=[in_tensor],
205205
pass_through_build_failures=True,
206206
enabled_precisions={torch.float, torch.bfloat16, torch.half},
207207
min_block_size=1,
@@ -239,7 +239,7 @@ def forward(self, x):
239239
exp_mod = torch.export.export(mod, (in_tensor,))
240240
trt_mod = torch_tensorrt.dynamo.compile(
241241
exp_mod,
242-
arg_inputs=[in_tensor],
242+
inputs=[in_tensor],
243243
pass_through_build_failures=True,
244244
enabled_precisions={torch.float, torch.bfloat16, torch.half},
245245
min_block_size=1,

tests/py/dynamo/models/test_model_refit.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def test_refit_one_engine():
108108
new_trt_gm = refit_module_weights(
109109
compiled_module=trt_gm,
110110
new_weight_module=exp_program2,
111-
inputs=inputs,
111+
arg_inputs=inputs,
112112
)
113113

114114
# Check the output
@@ -154,7 +154,7 @@ def test_refit_one_engine_bert():
154154
new_trt_gm = refit_module_weights(
155155
compiled_module=trt_gm,
156156
new_weight_module=exp_program2,
157-
inputs=inputs,
157+
arg_inputs=inputs,
158158
)
159159

160160
# Check the output
@@ -203,7 +203,7 @@ def test_refit_one_engine_inline_runtime():
203203
new_trt_gm = refit_module_weights(
204204
compiled_module=trt_gm,
205205
new_weight_module=exp_program2,
206-
inputs=inputs,
206+
arg_inputs=inputs,
207207
)
208208

209209
# Check the output
@@ -247,7 +247,7 @@ def test_refit_one_engine_python_runtime():
247247
new_trt_gm = refit_module_weights(
248248
compiled_module=trt_gm,
249249
new_weight_module=exp_program2,
250-
inputs=inputs,
250+
arg_inputs=inputs,
251251
)
252252

253253
# Check the output
@@ -313,7 +313,7 @@ def forward(self, x):
313313
new_trt_gm = refit_module_weights(
314314
compiled_module=trt_gm,
315315
new_weight_module=exp_program2,
316-
inputs=inputs,
316+
arg_inputs=inputs,
317317
)
318318

319319
# Check the output

tests/py/dynamo/models/test_models_export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def calibrate_loop(model):
222222
exp_program = torch.export.export(model, (input_tensor,))
223223
trt_model = torchtrt.dynamo.compile(
224224
exp_program,
225-
arg_inputs=[input_tensor],
225+
inputs=[input_tensor],
226226
enabled_precisions={torch.float8_e4m3fn},
227227
min_block_size=1,
228228
debug=True,

0 commit comments

Comments
 (0)