Skip to content

Commit 7bb66da

Browse files
committed
fix: Record cudagraphs when weight streaming budget has changed
1 parent 1b40bcc commit 7bb66da

File tree

4 files changed

+20
-5
lines changed

4 files changed

+20
-5
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,9 @@ bool TRTEngine::set_device_memory_budget(int64_t budget) {
307307
if (profile_execution) {
308308
enable_profiling();
309309
}
310+
// Indicates to reevaluate the runtime settings
311+
has_context_changed = true;
312+
310313
return result;
311314
}
312315

core/runtime/TRTEngine.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ struct TRTEngine : torch::CustomClassHolder {
102102
std::vector<at::Tensor> input_buffers = {};
103103
std::vector<at::Tensor> output_buffers = {};
104104
std::string shape_key;
105+
bool has_context_changed = false;
105106

106107
// TODO: Implement a call method
107108
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);

core/runtime/execute_engine.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,15 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
115115
}
116116

117117
// Whether cudagraphs needs to record the graph on this pass
118-
bool need_cudagraphs_record = (CUDAGRAPHS_MODE && (!_cudagraphs_validate_shapes(inputs, compiled_engine)));
118+
bool need_cudagraphs_record =
119+
(CUDAGRAPHS_MODE &&
120+
(!_cudagraphs_validate_shapes(inputs, compiled_engine) || compiled_engine->has_context_changed));
119121

120-
if (!CUDAGRAPHS_MODE) {
122+
if (!CUDAGRAPHS_MODE || compiled_engine->has_context_changed) {
121123
compiled_engine->cudagraph.reset();
122124
}
125+
// Reset the flag
126+
compiled_engine->has_context_changed = false;
123127

124128
// this is a buffer to store shape tensor input addresses throughout the runtime scope
125129
std::list<std::vector<int64_t>> inputShapeTensorValues;

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def __init__(
107107
self.engine = None
108108
self.weight_name_map = weight_name_map
109109
self.target_platform = Platform.current_platform()
110+
self.has_context_changed = False
110111

111112
if self.serialized_engine is not None and not self.settings.lazy_engine_init:
112113
self.setup_engine()
@@ -126,6 +127,9 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
126127
del self.context
127128
budget_bytes = self._set_device_memory_budget(budget_bytes)
128129
self.context = self.engine.create_execution_context()
130+
# Indicates to reevaluate the runtime settings
131+
self.has_context_changed = True
132+
129133
return budget_bytes
130134

131135
def _set_device_memory_budget(self, budget_bytes: int) -> int:
@@ -247,18 +251,21 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
247251
self._check_initialized()
248252

249253
cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
250-
need_cudagraphs_record = (
251-
cudagraphs_enabled and not self.cudagraphs_validate_shapes(inputs)
254+
need_cudagraphs_record = cudagraphs_enabled and (
255+
not self.cudagraphs_validate_shapes(inputs) or self.has_context_changed
252256
)
253257

254258
if need_cudagraphs_record:
255259
self._input_buffers = [None] * len(self.input_names)
256260
self._output_buffers = [None] * len(self.output_names)
257261

258-
if not cudagraphs_enabled and self.cudagraph:
262+
if self.cudagraph and (not cudagraphs_enabled or self.has_context_changed):
259263
self.cudagraph.reset()
260264
self.cudagraph = None
261265

266+
# Reset the flag
267+
self.has_context_changed = False
268+
262269
# If in safe mode, check at each iteration for for whether a switch is required
263270
if (
264271
torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE

0 commit comments

Comments
 (0)