Skip to content

Commit 70e2a38

Browse files
authored
fix: Record cudagraphs when weight streaming budget has changed (#3309)
1 parent 36d249d commit 70e2a38

File tree

7 files changed

+252
-49
lines changed

7 files changed

+252
-49
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ TRTEngine::TRTEngine(
101101

102102
runtime_states.old_cudagraphs = CUDAGRAPHS_MODE;
103103
runtime_states.old_pre_allocated_outputs = false;
104+
runtime_states.context_changed = false;
104105

105106
if (_in_binding_names.size() == 0 && _out_binding_names.size() == 0) {
106107
uint64_t inputs = 0;
@@ -310,6 +311,9 @@ bool TRTEngine::set_device_memory_budget(int64_t budget) {
310311
if (profile_execution) {
311312
enable_profiling();
312313
}
314+
// Indicates to reevaluate the runtime settings
315+
runtime_states.context_changed = true;
316+
313317
return result;
314318
}
315319

core/runtime/TRTEngine.h

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,25 +35,37 @@ struct TorchTRTRuntimeStates {
3535
bool old_cudagraphs;
3636
// Indicates whether pre-allocated output was enabled in the previous execute_engine
3737
bool old_pre_allocated_outputs;
38+
// Indicates whether context has changed
39+
bool context_changed;
3840

39-
// Evaluates whether certain conditions are met to enable CUDA Graph recording or to reuse pre-allocated outputs
41+
// Evaluates whether certain conditions are met to enable CUDA Graph recording/reset or to reuse pre-allocated outputs
4042
// based on the current and previous states, as well as input shape has changed
41-
std::tuple<bool, bool> set_runtime_states(bool new_cudagraphs, bool new_pre_allocated_output, bool shape_changed) {
43+
std::tuple<bool, bool, bool> set_runtime_states(
44+
bool new_cudagraphs,
45+
bool new_pre_allocated_output,
46+
bool shape_changed) {
4247
bool need_cudagraphs_record = false;
4348
bool can_use_pre_allocated_outputs = false;
49+
bool need_cudagraphs_reset = false;
4450

4551
// Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
46-
if (new_cudagraphs && (!old_cudagraphs || shape_changed)) {
52+
if (new_cudagraphs && (!old_cudagraphs || shape_changed || context_changed)) {
4753
need_cudagraphs_record = true;
4854
}
4955
// Pre-allocated output can be used when previous and current state are true without shape change
5056
if (old_pre_allocated_outputs && new_pre_allocated_output && !shape_changed) {
5157
can_use_pre_allocated_outputs = true;
5258
}
59+
if (!new_cudagraphs || shape_changed || context_changed) {
60+
need_cudagraphs_reset = true;
61+
}
62+
5363
old_cudagraphs = new_cudagraphs;
5464
old_pre_allocated_outputs = new_pre_allocated_output;
65+
// Reset flag
66+
context_changed = false;
5567

56-
return {need_cudagraphs_record, can_use_pre_allocated_outputs};
68+
return {need_cudagraphs_record, can_use_pre_allocated_outputs, need_cudagraphs_reset};
5769
}
5870
};
5971

core/runtime/execute_engine.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,9 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
211211

212212
bool need_cudagraphs_record = std::get<0>(result);
213213
bool can_use_pre_allocated_outputs = std::get<1>(result);
214+
bool need_cudagraphs_reset = std::get<2>(result);
214215

215-
if (!cudagraphs_enabled || shape_changed) {
216+
if (need_cudagraphs_reset) {
216217
compiled_engine->cudagraph.reset();
217218
}
218219

py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ def __init__(
2727
super(CudaGraphsTorchTensorRTModule, self).__init__()
2828
self.compiled_module = compiled_module
2929
self.inputs = partitioning.construct_submodule_inputs(compiled_module)
30+
self.is_weight_streaming_set = False
3031

3132
self._input_buffers: List[torch.Tensor] = []
3233
self._output_buffers: List[torch.Tensor] = []
3334
self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
3435
self.shape_key: Optional[str] = None
35-
self.prev_cudagraphs_enabled = False
3636
self._caller_stream: Optional[torch.cuda.Stream] = None
3737
self._engine_stream: Optional[torch.cuda.Stream] = None
3838
self.warm_up()
@@ -77,15 +77,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
7777
cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode()
7878
if cudagraphs_enabled:
7979
shape_changed = self.validate_input_shapes(inputs)
80-
# Cudagraphs record is required if cudagraphs_enabled is toggled to True regardless of shape change
81-
need_cudagraphs_record = not self.prev_cudagraphs_enabled or shape_changed
82-
self.prev_cudagraphs_enabled = cudagraphs_enabled
83-
80+
need_cudagraphs_record = shape_changed or self.is_weight_streaming_set
8481
if need_cudagraphs_record:
8582
if self.cudagraph:
8683
self.cudagraph.reset()
8784
self._input_buffers = [None] * len(self.inputs)
8885

86+
self.is_weight_streaming_set = False
8987
# Ensure inputs are available in all scopes and cast symbolic integers to Tensors
9088
contiguous_inputs: List[torch.Tensor] = [
9189
(

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -24,25 +24,33 @@
2424

2525

2626
class TorchTRTRuntimeStates:
27-
def __init__(self, new_cudagraphs: bool, new_pre_allocated_output: bool):
27+
def __init__(self, new_cudagraphs: bool):
2828
# Indicates whether CUDAGraphs were enabled in the previous execute_engine
2929
self.old_cudagraphs = new_cudagraphs
3030
# Indicates whether pre-allocated output was enabled in the previous execute_engine
31-
self.old_pre_allocated_outputs = new_pre_allocated_output
31+
self.old_pre_allocated_outputs = False
32+
# Indicates whether context has changed
33+
self.context_changed = False
3234

3335
def set_runtime_states(
3436
self,
3537
new_cudagraphs: bool,
3638
new_pre_allocated_output: bool,
3739
shape_changed: bool,
38-
) -> Tuple[bool, bool]:
39-
# Evaluates whether certain conditions are met to enable CUDA Graph recording or to reuse pre-allocated outputs
40+
) -> Tuple[bool, bool, bool]:
41+
# Evaluates whether certain conditions are met to enable CUDA Graph recording or to use pre-allocated outputs
4042
# based on the current and previous states, as well as input shape has changed
4143
need_cudagraphs_record = False
4244
can_use_pre_allocated_outputs = False
43-
44-
# Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
45-
if new_cudagraphs and (not self.old_cudagraphs or shape_changed):
45+
need_cudagraphs_reset = False
46+
47+
# CUDA Graph recording is needed if CUDA graphs is enabled and:
48+
# - CUDA graphs were previously disabled
49+
# - or the shape has changed
50+
# - or the execution context has changed (e.g., weight streaming)
51+
if new_cudagraphs and (
52+
not self.old_cudagraphs or shape_changed or self.context_changed
53+
):
4654
need_cudagraphs_record = True
4755

4856
# Pre-allocated output can be used when previous and current state are true without shape change
@@ -53,10 +61,19 @@ def set_runtime_states(
5361
):
5462
can_use_pre_allocated_outputs = True
5563

64+
if not new_cudagraphs or shape_changed or self.context_changed:
65+
need_cudagraphs_reset = True
66+
5667
self.old_cudagraphs = new_cudagraphs
5768
self.old_pre_allocated_outputs = new_pre_allocated_output
69+
# reset flag
70+
self.context_changed = False
5871

59-
return need_cudagraphs_record, can_use_pre_allocated_outputs
72+
return (
73+
need_cudagraphs_record,
74+
can_use_pre_allocated_outputs,
75+
need_cudagraphs_reset,
76+
)
6077

6178

6279
class PythonTorchTensorRTModule(Module): # type: ignore[misc]
@@ -145,7 +162,7 @@ def __init__(
145162
self.weight_name_map = weight_name_map
146163
self.target_platform = Platform.current_platform()
147164
self.runtime_states = TorchTRTRuntimeStates(
148-
torch_tensorrt.runtime.get_cudagraphs_mode(), False
165+
torch_tensorrt.runtime.get_cudagraphs_mode()
149166
)
150167
self.pre_allocated_outputs: List[torch.Tensor] = []
151168
self.use_pre_allocated_outputs = False
@@ -168,6 +185,7 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
168185
del self.context
169186
budget_bytes = self._set_device_memory_budget(budget_bytes)
170187
self.context = self.engine.create_execution_context()
188+
self.runtime_states.context_changed = True
171189
return budget_bytes
172190

173191
def _set_device_memory_budget(self, budget_bytes: int) -> int:
@@ -200,7 +218,6 @@ def setup_engine(self) -> None:
200218
if self.settings.enable_weight_streaming:
201219
self.set_default_device_memory_budget()
202220
self.context = self.engine.create_execution_context()
203-
204221
assert self.engine.num_io_tensors == (
205222
len(self.input_names) + len(self.output_names)
206223
)
@@ -356,22 +373,22 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
356373

357374
cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
358375
shape_changed = self.validate_input_shapes(inputs)
359-
need_cudagraphs_record, can_use_pre_allocated_outputs = (
360-
self.runtime_states.set_runtime_states(
361-
cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
362-
)
376+
(
377+
need_cudagraphs_record,
378+
can_use_pre_allocated_outputs,
379+
need_cudagraphs_reset,
380+
) = self.runtime_states.set_runtime_states(
381+
cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
363382
)
364383

384+
if need_cudagraphs_reset and self.cudagraph:
385+
self.cudagraph.reset()
386+
self.cudagraph = None
387+
365388
if need_cudagraphs_record:
366-
if self.cudagraph:
367-
self.cudagraph.reset()
368389
self._input_buffers = [None] * len(self.input_names)
369390
self._output_buffers = [None] * len(self.output_names)
370391

371-
if not cudagraphs_enabled and self.cudagraph:
372-
self.cudagraph.reset()
373-
self.cudagraph = None
374-
375392
# If in safe mode, check at each iteration for whether a switch is required
376393
if (
377394
torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE

py/torch_tensorrt/runtime/_weight_streaming.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ def __init__(
2020
) -> None:
2121
rt_mods = []
2222
self.current_device_budget = 0
23+
self.cuda_graphs_module = None
2324

2425
if isinstance(module, CudaGraphsTorchTensorRTModule):
26+
self.cuda_graphs_module = module
2527
module = module.compiled_module
2628
for name, rt_mod in module.named_children():
2729
if "_run_on_acc" in name and isinstance(
@@ -78,6 +80,8 @@ def _set_streamable_weight_bytes(self, requested_budget: int) -> int:
7880
ws_budget_bytes += rt_mod.set_device_memory_budget(normalized_size[i])
7981
logger.debug(f"Set weight streaming size {normalized_size[i]} for {name}")
8082

83+
if self.cuda_graphs_module:
84+
self.cuda_graphs_module.is_weight_streaming_set = True
8185
return ws_budget_bytes
8286

8387
def __setattr__(self, name: str, value: Any) -> None:

0 commit comments

Comments (0)