Skip to content

Commit 8abb537

Browse files
committed
Implemented basic pipeline for Refitting (#2886)
1 parent 7e4da0d commit 8abb537

File tree

19 files changed

+961
-50
lines changed

19 files changed

+961
-50
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,16 @@ TRTEngine::TRTEngine(
3333
const RTDevice& cuda_device,
3434
const std::vector<std::string>& _in_binding_names,
3535
const std::vector<std::string>& _out_binding_names,
36-
bool hardware_compatible)
36+
bool hardware_compatible,
37+
const std::string& serialized_metadata)
3738
: TRTEngine(
3839
"deserialized_trt",
3940
serialized_engine,
4041
cuda_device,
4142
_in_binding_names,
4243
_out_binding_names,
43-
hardware_compatible) {}
44+
hardware_compatible,
45+
serialized_metadata) {}
4446

4547
TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
4648
: TRTEngine(
@@ -49,17 +51,19 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
4951
RTDevice(serialized_info[DEVICE_IDX]),
5052
split(serialized_info[INPUT_BINDING_NAMES_IDX], BINDING_DELIM),
5153
split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM),
52-
static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX]))) {}
54+
static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX])),
55+
serialized_info[SERIALIZED_METADATA_IDX]) {}
5356

5457
TRTEngine::TRTEngine(
5558
const std::string& mod_name,
5659
const std::string& serialized_engine,
5760
const RTDevice& cuda_device,
5861
const std::vector<std::string>& _in_binding_names,
5962
const std::vector<std::string>& _out_binding_names,
60-
bool hardware_compatible) {
63+
bool hardware_compatible,
64+
const std::string& serialized_metadata) {
6165
this->hardware_compatible = hardware_compatible;
62-
66+
this->serialized_metadata = serialized_metadata;
6367
auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible);
6468
TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine");
6569
device_info = most_compatible_device.value();

core/runtime/TRTEngine.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,22 +35,26 @@ struct TRTEngine : torch::CustomClassHolder {
3535
std::vector<std::string> out_binding_names = {}; // ITO: PYT IDX
3636

3737
bool hardware_compatible = false; // Whether the engine was compiled in hardware compatible mode
38+
std::string serialized_metadata; // This is a base64 encoded pkl object used to store metadata such as settings used
39+
// in compilation
3840

3941
~TRTEngine();
4042
TRTEngine(
4143
const std::string& serialized_engine,
4244
const RTDevice& cuda_device,
4345
const std::vector<std::string>& in_binding_names,
4446
const std::vector<std::string>& out_binding_names,
45-
bool hardware_compatible = false);
47+
bool hardware_compatible = false,
48+
const std::string& serialized_metadata = "");
4649
TRTEngine(std::vector<std::string> serialized_info);
4750
TRTEngine(
4851
const std::string& mod_name,
4952
const std::string& serialized_engine,
5053
const RTDevice& cuda_device,
5154
const std::vector<std::string>& in_binding_names,
5255
const std::vector<std::string>& out_binding_names,
53-
bool hardware_compatible = false);
56+
bool hardware_compatible = false,
57+
const std::string& serialized_metadata = "");
5458
TRTEngine& operator=(const TRTEngine& other);
5559
std::string to_str() const;
5660
static void verify_serialization_fmt(const std::vector<std::string>& serialized_info);

core/runtime/register_jit_hooks.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
102102
serialize_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings(self->in_binding_names);
103103
serialize_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(self->out_binding_names);
104104
serialize_info[HW_COMPATIBLE_IDX] = self->hardware_compatible ? "1" : "0";
105-
105+
serialize_info[SERIALIZED_METADATA_IDX] = self->serialized_metadata;
106106
LOG_DEBUG("Serialized Hardware Compatibility: " << (self->hardware_compatible ? "Enabled" : "Disabled"));
107107

108108
return serialize_info;
@@ -127,6 +127,15 @@ TORCH_LIBRARY(tensorrt, m) {
127127
});
128128
m.def(
129129
"get_logging_level", []() -> int64_t { return int64_t(util::logging::get_logger().get_reportable_log_level()); });
130+
m.def("ABI_TARGET_IDX", []() -> int64_t { return ABI_TARGET_IDX; });
131+
m.def("NAME_IDX", []() -> int64_t { return NAME_IDX; });
132+
m.def("DEVICE_IDX", []() -> int64_t { return DEVICE_IDX; });
133+
m.def("ENGINE_IDX", []() -> int64_t { return ENGINE_IDX; });
134+
m.def("INPUT_BINDING_NAMES_IDX", []() -> int64_t { return INPUT_BINDING_NAMES_IDX; });
135+
m.def("OUTPUT_BINDING_NAMES_IDX", []() -> int64_t { return OUTPUT_BINDING_NAMES_IDX; });
136+
m.def("HW_COMPATIBLE_IDX", []() -> int64_t { return HW_COMPATIBLE_IDX; });
137+
m.def("SERIALIZED_METADATA_IDX", []() -> int64_t { return SERIALIZED_METADATA_IDX; });
138+
m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; });
130139
}
131140

132141
} // namespace

core/runtime/runtime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ typedef enum {
2525
INPUT_BINDING_NAMES_IDX,
2626
OUTPUT_BINDING_NAMES_IDX,
2727
HW_COMPATIBLE_IDX,
28+
SERIALIZED_METADATA_IDX,
2829
SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
2930
} SerializedInfoIndex;
3031

docsrc/py_api/dynamo.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ Functions
2424

2525
.. autofunction:: convert_module_to_trt_engine
2626

27-
27+
.. autofunction:: refit_module_weights
2828

2929
Classes
3030
--------

examples/dynamo/README.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ a number of ways you can leverage this backend to accelerate inference.
1111
* :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
1212
* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
1313
* :ref:`custom_kernel_plugins`: Creating a plugin to use a custom kernel inside TensorRT engines
14+
* :ref:`refit_engine_example`: Refitting a compiled TensorRT Graph Module with updated weights
1415
* :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 and PTQ using ``torch.compile``
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
"""
2+
.. _refit_engine_example:
3+
4+
Refit TenorRT Graph Module with Torch-TensorRT
5+
===================================================================
6+
7+
We are going to demonstrate how a compiled TensorRT Graph Module can be refitted with updated weights.
8+
9+
In many cases, we frequently update the weights of models, such as applying various LoRA to Stable Diffusion or constant A/B testing of AI products.
10+
That poses challenges for TensorRT inference optimizations, as compiling the TensorRT engines takes significant time, making repetitive compilation highly inefficient.
11+
Torch-TensorRT supports refitting TensorRT graph modules without re-compiling the engine, considerably accelerating the workflow.
12+
13+
In this tutorial, we are going to walk through
14+
1. Compiling a PyTorch model to a TensorRT Graph Module
15+
2. Save and load a graph module
16+
3. Refit the graph module
17+
"""
18+
19+
# %%
20+
# Standard Workflow
21+
# -----------------------------
22+
23+
# %%
24+
# Imports and model definition
25+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
26+
27+
import numpy as np
28+
import torch
29+
import torch_tensorrt as torch_trt
30+
import torchvision.models as models
31+
from torch_tensorrt.dynamo import refit_module_weights
32+
33+
np.random.seed(0)
34+
torch.manual_seed(0)
35+
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]
36+
37+
38+
# %%
39+
# Compile the module for the first time and save it.
40+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
41+
42+
model = models.resnet18(pretrained=False).eval().to("cuda")
43+
exp_program = torch.export.export(model, tuple(inputs))
44+
enabled_precisions = {torch.float}
45+
debug = False
46+
workspace_size = 20 << 30
47+
min_block_size = 0
48+
use_python_runtime = False
49+
torch_executed_ops = {}
50+
trt_gm = torch_trt.dynamo.compile(
51+
exp_program,
52+
tuple(inputs),
53+
use_python_runtime=use_python_runtime,
54+
enabled_precisions=enabled_precisions,
55+
debug=debug,
56+
min_block_size=min_block_size,
57+
torch_executed_ops=torch_executed_ops,
58+
make_refitable=True,
59+
) # Output is a torch.fx.GraphModule
60+
61+
# Save the graph module as an exported program
62+
# This is only supported when use_python_runtime = False
63+
torch_trt.save(trt_gm, "./compiled.ep", inputs=inputs)
64+
65+
66+
# %%
67+
# Refit the module with update model weights
68+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
69+
70+
# Create and compile the updated model
71+
model2 = models.resnet18(pretrained=True).eval().to("cuda")
72+
exp_program2 = torch.export.export(model2, tuple(inputs))
73+
74+
75+
compiled_trt_ep = torch_trt.load("./compiled.ep")
76+
77+
# This returns a new module with updated weights
78+
new_trt_gm = refit_module_weights(
79+
compiled_module=compiled_trt_ep,
80+
new_weight_module=exp_program2,
81+
inputs=inputs,
82+
)
83+
84+
# Check the output
85+
expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(*inputs)
86+
for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
87+
assert torch.allclose(
88+
expected_output, refitted_output, 1e-2, 1e-2
89+
), "Refit Result is not correct. Refit failed"
90+
91+
print("Refit successfully!")
92+
93+
# %%
94+
# Alterative Workflow using Python Runtime
95+
# -----------------------------
96+
97+
# Currently python runtime does not support engine serialization. So the refitting will be done in the same runtime.
98+
# This usecase is more useful when you need to switch different weights in the same runtime, such as using Stable Diffusion.

py/torch_tensorrt/_compile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,9 +351,9 @@ def convert_method_to_trt_engine(
351351
torchtrt_inputs = prepare_inputs(inputs)
352352
exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
353353

354-
return dynamo_convert_module_to_trt_engine( # type: ignore[no-any-return]
354+
return dynamo_convert_module_to_trt_engine(
355355
exp_program,
356-
inputs=inputs,
356+
inputs=tuple(inputs),
357357
enabled_precisions=enabled_precisions_set,
358358
**kwargs,
359359
)

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
1010
from ._compiler import compile, convert_module_to_trt_engine
1111
from ._exporter import export
12+
from ._refit import refit_module_weights
1213
from ._settings import CompilationSettings
1314
from ._SourceIR import SourceIR
1415
from ._tracer import trace

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def compile(
5959
Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
6060
) = _defaults.ENABLED_PRECISIONS,
6161
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
62-
refit: bool = _defaults.REFIT,
62+
make_refitable: bool = _defaults.MAKE_REFITABLE,
6363
debug: bool = _defaults.DEBUG,
6464
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
6565
workspace_size: int = _defaults.WORKSPACE_SIZE,
@@ -162,6 +162,18 @@ def compile(
162162
)
163163
if kwarg_inputs is None:
164164
kwarg_inputs = {}
165+
166+
if "refit" in kwargs.keys():
167+
warnings.warn(
168+
"Refit is deprecated. Please use make_refitable=True if you want to enable refitting of the engine.",
169+
DeprecationWarning,
170+
stacklevel=2,
171+
)
172+
if make_refitable:
173+
raise ValueError("Use flag make_refitable only. Flag refit is deprecated.")
174+
else:
175+
make_refitable = kwargs["refit"]
176+
165177
engine_capability = EngineCapability._from(engine_capability)
166178

167179
if torch_executed_modules is not None and torch_executed_modules:
@@ -229,7 +241,7 @@ def compile(
229241
"require_full_compilation": require_full_compilation,
230242
"disable_tf32": disable_tf32,
231243
"sparse_weights": sparse_weights,
232-
"refit": refit,
244+
"make_refitable": make_refitable,
233245
"engine_capability": engine_capability,
234246
"dla_sram_size": dla_sram_size,
235247
"dla_local_dram_size": dla_local_dram_size,
@@ -497,7 +509,7 @@ def convert_module_to_trt_engine(
497509
require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION,
498510
disable_tf32: bool = _defaults.DISABLE_TF32,
499511
sparse_weights: bool = _defaults.SPARSE_WEIGHTS,
500-
refit: bool = _defaults.REFIT,
512+
make_refitable: bool = _defaults.MAKE_REFITABLE,
501513
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
502514
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
503515
dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -580,6 +592,12 @@ def convert_module_to_trt_engine(
580592
DeprecationWarning,
581593
stacklevel=2,
582594
)
595+
if "refit" in kwargs.keys():
596+
warnings.warn(
597+
"Refit is deprecated. Please use make_refitable=True if you want to enable refitting of the engine.",
598+
DeprecationWarning,
599+
stacklevel=2,
600+
)
583601

584602
input_list = list(inputs) if inputs is not None else []
585603
torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set()
@@ -608,7 +626,7 @@ def convert_module_to_trt_engine(
608626
"require_full_compilation": require_full_compilation,
609627
"disable_tf32": disable_tf32,
610628
"sparse_weights": sparse_weights,
611-
"refit": refit,
629+
"make_refitable": make_refitable,
612630
"engine_capability": engine_capability,
613631
"num_avg_timing_iters": num_avg_timing_iters,
614632
"dla_sram_size": dla_sram_size,

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
USE_PYTHON_RUNTIME = False
2727
USE_FAST_PARTITIONER = True
2828
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
29-
REFIT = False
29+
MAKE_REFITABLE = False
3030
REQUIRE_FULL_COMPILATION = False
3131
DRYRUN = False
3232
HARDWARE_COMPATIBLE = False

0 commit comments

Comments
 (0)