Skip to content

Commit 7c5123a

Browse files
committed
fix: Record cudagraphs when weight streaming budget has changed
1 parent ada9156 commit 7c5123a

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,9 @@ bool TRTEngine::set_device_memory_budget(int64_t budget) {
310310
if (profile_execution) {
311311
enable_profiling();
312312
}
313+
// Signals that the runtime settings need to be re-evaluated
314+
has_context_changed = true;
315+
313316
return result;
314317
}
315318

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,15 @@ def __init__(
141141
self.engine = None
142142
self.weight_name_map = weight_name_map
143143
self.target_platform = Platform.current_platform()
144+
<<<<<<< HEAD
144145
self.runtime_states = TorchTRTRuntimeStates(
145146
torch_tensorrt.runtime.get_cudagraphs_mode(), False
146147
)
147148
self.pre_allocated_outputs: List[torch.Tensor] = []
148149
self.use_pre_allocated_outputs = False
150+
=======
151+
self.has_context_changed = False
152+
>>>>>>> 7bb66dac4 (fix: Record cudagraphs when weight streaming budget has changed)
149153

150154
if self.serialized_engine is not None and not self.settings.lazy_engine_init:
151155
self.setup_engine()
@@ -165,6 +169,9 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
165169
del self.context
166170
budget_bytes = self._set_device_memory_budget(budget_bytes)
167171
self.context = self.engine.create_execution_context()
172+
# Signals that the runtime settings need to be re-evaluated
173+
self.has_context_changed = True
174+
168175
return budget_bytes
169176

170177
def _set_device_memory_budget(self, budget_bytes: int) -> int:
@@ -353,11 +360,16 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
353360
self._check_initialized()
354361

355362
cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
363+
<<<<<<< HEAD
356364
shape_changed = self.validate_input_shapes(inputs)
357365
need_cudagraphs_record, can_use_pre_allocated_outputs = (
358366
self.runtime_states.validate_states(
359367
cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
360368
)
369+
=======
370+
need_cudagraphs_record = cudagraphs_enabled and (
371+
not self.cudagraphs_validate_shapes(inputs) or self.has_context_changed
372+
>>>>>>> 7bb66dac4 (fix: Record cudagraphs when weight streaming budget has changed)
361373
)
362374

363375
if need_cudagraphs_record:
@@ -366,11 +378,18 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
366378
self._input_buffers = [None] * len(self.input_names)
367379
self._output_buffers = [None] * len(self.output_names)
368380

369-
if not cudagraphs_enabled and self.cudagraph:
381+
if self.cudagraph and (not cudagraphs_enabled or self.has_context_changed):
370382
self.cudagraph.reset()
371383
self.cudagraph = None
372384

385+
<<<<<<< HEAD
373386
# If in safe mode, check at each iteration for whether a switch is required
387+
=======
388+
# Reset the flag
389+
self.has_context_changed = False
390+
391+
# If in safe mode, check at each iteration for whether a switch is required
392+
>>>>>>> 7bb66dac4 (fix: Record cudagraphs when weight streaming budget has changed)
374393
if (
375394
torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
376395
):

0 commit comments

Comments
 (0)