
Commit 92bf700

feat: Support weight streaming (#3111)
1 parent dad195b

20 files changed: +680 −6 lines changed

core/runtime/TRTEngine.cpp

Lines changed: 38 additions & 0 deletions
@@ -89,6 +89,12 @@ TRTEngine::TRTEngine(
   cuda_engine = make_trt(rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size()));
   TORCHTRT_CHECK((cuda_engine.get() != nullptr), "Unable to deserialize the TensorRT engine");

+  if (get_streamable_device_memory_budget() > 0) {
+    int64_t budget_bytes = get_automatic_device_memory_budget();
+    LOG_DEBUG("Weight streaming budget set to " << budget_bytes << "B");
+    cuda_engine->setWeightStreamingBudgetV2(budget_bytes);
+  }
+
   exec_ctx = make_trt(cuda_engine->createExecutionContext());
   TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to create TensorRT execution context");

@@ -258,6 +264,38 @@ void TRTEngine::set_profiling_paths() {
   cuda_graph_debug_path = std::filesystem::path{profile_path_prefix + "/" + name + "_cudagraph.dot"}.string();
 }

+int64_t TRTEngine::get_device_memory_budget() {
+  return cuda_engine->getWeightStreamingBudgetV2();
+}
+
+bool TRTEngine::set_device_memory_budget(int64_t budget) {
+  // Recreate the context because the weight streaming budget cannot be modified while there is an active context.
+  if (exec_ctx.get() != nullptr) {
+    exec_ctx.reset();
+  }
+  if (profile_execution) {
+    trt_engine_profiler.reset();
+  }
+  bool result = cuda_engine->setWeightStreamingBudgetV2(budget);
+  exec_ctx = make_trt(cuda_engine->createExecutionContext());
+  TORCHTRT_CHECK(
+      (exec_ctx.get() != nullptr),
+      "Unable to recreate TensorRT execution context after setting new device memory budget");
+  if (profile_execution) {
+    enable_profiling();
+  }
+  return result;
+}
+
+// Returns 0 if BuilderFlag::kWEIGHT_STREAMING is unset during engine building.
+int64_t TRTEngine::get_streamable_device_memory_budget() {
+  return cuda_engine->getStreamableWeightsSize();
+}
+
+int64_t TRTEngine::get_automatic_device_memory_budget() {
+  return cuda_engine->getWeightStreamingAutomaticBudget();
+}
+
 std::string TRTEngine::to_str() const {
   // clang-format off
   std::stringstream ss;
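
On the Python side, these runtime methods surface through the torch_tensorrt.runtime.weight_streaming context manager used in the example added by this commit. A minimal sketch of the budget semantics, assuming trt_model is a module compiled with enable_weight_streaming=True:

import torch_tensorrt

with torch_tensorrt.runtime.weight_streaming(trt_model) as ctx:
    # total_device_budget mirrors get_streamable_device_memory_budget(); it is 0
    # when the engine was built without BuilderFlag::kWEIGHT_STREAMING.
    total = ctx.total_device_budget
    # Assigning device_budget drives set_device_memory_budget(), which tears down
    # and recreates the execution context as in the C++ above.
    ctx.device_budget = total // 2  # stream roughly half of the weights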

core/runtime/TRTEngine.h

Lines changed: 4 additions & 0 deletions
@@ -71,6 +71,10 @@ struct TRTEngine : torch::CustomClassHolder {
   std::string get_engine_layer_info();
   void dump_engine_layer_info_to_file(const std::string& path);
   void dump_engine_layer_info();
+  int64_t get_device_memory_budget();
+  bool set_device_memory_budget(int64_t budget);
+  int64_t get_streamable_device_memory_budget();
+  int64_t get_automatic_device_memory_budget();
   friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
   static const char BINDING_DELIM = '%';

core/runtime/register_jit_hooks.cpp

Lines changed: 6 additions & 0 deletions
@@ -86,6 +86,12 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
         .def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file)
         .def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info)
         .def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
+        .def_property(
+            "device_memory_budget",
+            &TRTEngine::get_device_memory_budget,
+            &TRTEngine::set_device_memory_budget)
+        .def_property("streamable_device_memory_budget", &TRTEngine::get_streamable_device_memory_budget)
+        .def_property("automatic_device_memory_budget", &TRTEngine::get_automatic_device_memory_budget)
         .def_pickle(
             [](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
               // Serialize TensorRT engine
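
The def_property bindings above expose the budget accessors as attributes on the registered TorchScript class, so attribute reads and writes dispatch to the C++ getters and setters. A hypothetical sketch for illustration; how the engine handle is obtained here (a plain engine variable holding the bound object) is an assumption, not a documented access path:

# engine: a handle to the bound TRTEngine custom class (assumed for illustration)
streamable = engine.streamable_device_memory_budget  # -> get_streamable_device_memory_budget()
auto_budget = engine.automatic_device_memory_budget  # -> get_automatic_device_memory_budget()
engine.device_memory_budget = auto_budget            # -> set_device_memory_budget(auto_budget)
current = engine.device_memory_budget                # -> get_device_memory_budget()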

docsrc/index.rst

Lines changed: 2 additions & 0 deletions
@@ -66,6 +66,7 @@ Tutorials
 * :ref:`converter_overloading`
 * :ref:`custom_kernel_plugins`
 * :ref:`mutable_torchtrt_module_example`
+* :ref:`weight_streaming_example`

 .. toctree::
    :caption: Tutorials
@@ -82,6 +83,7 @@ Tutorials
    tutorials/_rendered_examples/dynamo/converter_overloading
    tutorials/_rendered_examples/dynamo/custom_kernel_plugins
    tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example
+   tutorials/_rendered_examples/dynamo/weight_streaming_example

 Dynamo Frontend
 ----------------
examples/dynamo/weight_streaming_example.py

Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
"""
.. _weight_streaming_example:

Weight Streaming
=======================

Weight streaming in TensorRT is a powerful feature designed to overcome GPU memory limitations
when working with large models. It enables running models larger than available GPU memory
by streaming weight data from host (CPU) memory to GPU memory during inference.

Streaming larger amounts of memory will likely result in lower performance. But if
streaming weights allows the user to run larger batch sizes, it can lead to higher throughput.
This increased throughput can sometimes outweigh the slowdown caused by streaming weights.
The optimal amount of memory to stream varies depending on the specific model and hardware.
Experimenting with different memory limits can help find the best balance between streaming
overhead and batch size benefits.

This example uses a pre-trained Llama-2 model and shows how to use the weight streaming
feature with Torch-TensorRT.

1. Compile option - build the TRT engine with the weight streaming feature
2. Runtime API - control the weight streaming budget via a context manager
"""

# %%
# Imports and Model Definition
# ----------------------------------

import copy
import timeit

import numpy as np
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM
from utils import export_llm


def time_generate(model, inputs, output_seq_length, iterations=10):
    """
    Measure the time for generating a sentence over a certain number of iterations
    """
    # We only support single input (B x seq_len) for LLMs now
    input_seq = inputs[0]
    with torch.no_grad():
        timings = []
        for _ in range(iterations):
            start_time = timeit.default_timer()
            inputs_copy = copy.copy(input_seq)
            # Greedy decoding of the model. This generates up to max_tokens.
            while inputs_copy.shape[1] <= output_seq_length:
                outputs = model(inputs_copy)
                logits = outputs.logits
                next_token_logits = logits[:, -1, :]
                next_tokens = torch.argmax(next_token_logits, dim=-1)
                inputs_copy = torch.cat([inputs_copy, next_tokens[:, None]], dim=-1)
            torch.cuda.synchronize()
            end_time = timeit.default_timer()
            timings.append(end_time - start_time)

    times = np.array(timings)
    time_mean_ms = np.mean(times) * 1000

    return time_mean_ms


# Load the LLaMA-2 model
DEVICE = torch.device("cuda:0")
llama_path = "meta-llama/Llama-2-7b-chat-hf"
with torch.no_grad():
    model = AutoModelForCausalLM.from_pretrained(
        llama_path, use_cache=False, attn_implementation="eager"
    ).eval()

# Set input and output sequence lengths
isl = 128
osl = 256

# Create random input tensors
input_tensors = [torch.randint(0, 5, (1, isl), dtype=torch.int64).cuda()]
# Convert the model to half precision (FP16)
model = model.half()
# Exports the LLM model into an ExportedProgram with dynamic shapes.
llama2_ep = export_llm(model, input_tensors[0], max_seq_len=osl)

# %%
# Compiler option
# ----------------------------------
#
# The enable_weight_streaming=True and use_explicit_typing=True options are required to
# build the engine with the weight streaming feature. use_explicit_typing=True creates a
# `strongly typed network <https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#strongly-typed-networks>`_
# and only float32 precision is allowed in the enabled_precisions option.
#

# Create a TensorRT-compiled model
trt_model = torch_tensorrt.dynamo.compile(
    llama2_ep,
    inputs=input_tensors,
    enabled_precisions={torch.float32},
    truncate_double=True,
    device=DEVICE,
    use_explicit_typing=True,
    enable_weight_streaming=True,
)

# Warm up for 3 iterations
_ = time_generate(trt_model, input_tensors, osl, 3)

# %%
# Running with automatic budget size
# ----------------------------------
#
# Once you specify the enable_weight_streaming compile option, an automatic budget size is
# configured. This automatic size may not always provide the optimal solution because the
# automatically determined budget lacks insight into the user's specific memory constraints
# and usage patterns.

# Weight streaming context to get current weight budget information
weight_streaming_ctx = torch_tensorrt.runtime.weight_streaming(trt_model)
# Measure the mean latency of the model with weight streaming
mean_latency = time_generate(trt_model, input_tensors, osl, 1)
# Calculate the percentage of the current weight budget used
weight_budget_pct = (
    weight_streaming_ctx.device_budget / weight_streaming_ctx.total_device_budget * 100
)
print(
    f"Set weight streaming budget as {weight_budget_pct}%. {weight_streaming_ctx.device_budget} bytes out of {weight_streaming_ctx.total_device_budget}. mean latency = {mean_latency} ms"
)

# %%
# Running with the weight streaming context manager
# -------------------------------------------------
#
# The weight streaming budget can be limited by using the weight streaming context manager.
# The permissible range for the budget size is from 0 to ctx.total_device_budget.
# 0 means maximum memory savings by using the minimum amount of memory; a value equal to
# ctx.total_device_budget disables weight streaming.
# If multiple TRT engines are created, budgets are distributed proportionally.

# Use a context manager for weight streaming
with torch_tensorrt.runtime.weight_streaming(trt_model) as weight_streaming_ctx:
    # Get the total size of streamable weights in the engine
    streamable_budget = weight_streaming_ctx.total_device_budget

    # Scenario 1: Automatic weight streaming budget
    # Get the automatically determined weight streaming budget
    requested_budget = weight_streaming_ctx.get_automatic_weight_streaming_budget()
    # Set the device budget to the automatically determined value
    weight_streaming_ctx.device_budget = requested_budget
    # Measure the mean latency with the automatic budget
    mean_latency = time_generate(trt_model, input_tensors, osl, 1)
    # Calculate the percentage of the weight budget used
    weight_budget_pct = (
        weight_streaming_ctx.device_budget
        / weight_streaming_ctx.total_device_budget
        * 100
    )
    print(
        f"Set auto weight streaming budget as {weight_budget_pct}%. {weight_streaming_ctx.device_budget} bytes out of {weight_streaming_ctx.total_device_budget}. mean latency = {mean_latency} ms"
    )

    # Scenario 2: Manual 10% weight streaming budget
    # Set the budget to 10% of the total streamable weights
    requested_budget = int(streamable_budget * 0.1)
    weight_streaming_ctx.device_budget = requested_budget
    # Measure the mean latency with the 10% budget
    mean_latency = time_generate(trt_model, input_tensors, osl, 1)
    # Calculate the percentage of the weight budget used
    weight_budget_pct = (
        weight_streaming_ctx.device_budget
        / weight_streaming_ctx.total_device_budget
        * 100
    )
    print(
        f"Set weight streaming budget as {weight_budget_pct}%. {weight_streaming_ctx.device_budget} bytes out of {weight_streaming_ctx.total_device_budget}. mean latency = {mean_latency} ms"
    )
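
The example imports export_llm from a local utils module that is not part of this diff. A minimal sketch of what such a helper might look like, assuming torch.export with a dynamic sequence-length dimension; the real helper may differ:

import torch
from torch.export import Dim, export

def export_llm(model, inputs, max_seq_len=256):
    # Mark the sequence dimension as dynamic up to max_seq_len so the resulting
    # ExportedProgram accepts the growing inputs produced by greedy decoding.
    seq_len = Dim("seq_len", min=1, max=max_seq_len)
    with torch.no_grad():
        return export(model, (inputs,), dynamic_shapes=({1: seq_len},))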

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 10 additions & 0 deletions
@@ -90,6 +90,7 @@ def compile(
     custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
     use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING,
     use_fp32_acc: bool = _defaults.USE_FP32_ACC,
+    enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -162,6 +163,7 @@ def compile(
         custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored.
         use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs.
         use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions.
+        enable_weight_streaming (bool): Enable weight streaming.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -215,6 +217,10 @@ def compile(
             This flag inserts casts around matmul layers and ensures TensorRT executes the matmul layers in FP16 with FP32 accumulation."
         )

+    if enable_weight_streaming and not use_explicit_typing:
+        raise AssertionError(
+            "When enable_weight_streaming is enabled, it requires use_explicit_typing to be set to True"
+        )
     # Aliasing inputs to arg_inputs for better understanding
     if not arg_inputs and not inputs:
         raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
@@ -291,6 +297,7 @@ def compile(
         "reuse_cached_engines": reuse_cached_engines,
         "use_explicit_typing": use_explicit_typing,
         "use_fp32_acc": use_fp32_acc,
+        "enable_weight_streaming": enable_weight_streaming,
     }

     settings = CompilationSettings(**compilation_options)
@@ -549,6 +556,7 @@ def convert_exported_program_to_serialized_trt_engine(
     timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
     use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING,
     use_fp32_acc: bool = _defaults.USE_FP32_ACC,
+    enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
     **kwargs: Any,
 ) -> bytes:
     """Convert an ExportedProgram to a serialized TensorRT engine
@@ -609,6 +617,7 @@ def convert_exported_program_to_serialized_trt_engine(
         timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
         use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs.
         use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions.
+        enable_weight_streaming (bool): Enable weight streaming.
     Returns:
         bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
     """
@@ -684,6 +693,7 @@ def convert_exported_program_to_serialized_trt_engine(
         "timing_cache_path": timing_cache_path,
         "use_explicit_typing": use_explicit_typing,
         "use_fp32_acc": use_fp32_acc,
+        "enable_weight_streaming": enable_weight_streaming,
     }

     settings = CompilationSettings(**compilation_options)
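
A minimal usage sketch against the options added above; exported_program and example_inputs are placeholders. The key constraint, enforced by the new check, is that enable_weight_streaming=True must be paired with use_explicit_typing=True:

import torch
import torch_tensorrt

trt_module = torch_tensorrt.dynamo.compile(
    exported_program,                    # placeholder ExportedProgram
    inputs=example_inputs,               # placeholder input tensors
    enabled_precisions={torch.float32},  # strong typing permits only float32 here
    use_explicit_typing=True,            # required when weight streaming is enabled
    enable_weight_streaming=True,
)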

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@
 CUSTOM_ENGINE_CACHE = None
 USE_EXPLICIT_TYPING = False
 USE_FP32_ACC = False
+ENABLE_WEIGHT_STREAMING = False


 def default_device() -> Device:

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 4 additions & 0 deletions
@@ -14,6 +14,7 @@
     DLA_SRAM_SIZE,
     DRYRUN,
     ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
+    ENABLE_WEIGHT_STREAMING,
     ENABLED_PRECISIONS,
     ENGINE_CAPABILITY,
     HARDWARE_COMPATIBLE,
@@ -82,6 +83,7 @@ class CompilationSettings:
         reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
         use_strong_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs.
         use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions.
+        enable_weight_streaming (bool): Enable weight streaming.
     """

     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -118,6 +120,7 @@ class CompilationSettings:
     reuse_cached_engines: bool = REUSE_CACHED_ENGINES
     use_explicit_typing: bool = USE_EXPLICIT_TYPING
     use_fp32_acc: bool = USE_FP32_ACC
+    enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING


 _SETTINGS_TO_BE_ENGINE_INVARIANT = (
@@ -130,6 +133,7 @@ class CompilationSettings:
     "make_refittable",
     "engine_capability",
     "hardware_compatible",
+    "enable_weight_streaming",
 )
py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 3 additions & 0 deletions
@@ -305,6 +305,9 @@ def _populate_trt_builder_config(
         if tactic_sources is not None:
             builder_config.set_tactic_sources(tactic_sources=tactic_sources)

+        if self.compilation_settings.enable_weight_streaming:
+            builder_config.set_flag(trt.BuilderFlag.WEIGHT_STREAMING)
+
         return builder_config

     def _create_timing_cache(
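
For orientation, the flag set above corresponds to BuilderFlag.WEIGHT_STREAMING in the raw TensorRT Python API. A sketch under the assumption of the TensorRT 10 Python bindings; verify names against the installed TensorRT version:

import tensorrt as trt

builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
# Weight streaming requires a strongly typed network, matching the
# use_explicit_typing requirement in the compiler front end.
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
)
config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.WEIGHT_STREAMING)  # same flag as in the diff above
# ... populate the network, then builder.build_serialized_network(network, config)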
