@@ -141,11 +141,15 @@ def __init__(
141
141
self .engine = None
142
142
self .weight_name_map = weight_name_map
143
143
self .target_platform = Platform .current_platform ()
144
+ < << << << HEAD
144
145
self .runtime_states = TorchTRTRuntimeStates (
145
146
torch_tensorrt .runtime .get_cudagraphs_mode (), False
146
147
)
147
148
self .pre_allocated_outputs : List [torch .Tensor ] = []
148
149
self .use_pre_allocated_outputs = False
150
+ == == == =
151
+ self .has_context_changed = False
152
+ >> >> >> > 7 bb66dac4 (fix : Record cudagraphs when weight streaming budget has changed )
149
153
150
154
if self .serialized_engine is not None and not self .settings .lazy_engine_init :
151
155
self .setup_engine ()
@@ -165,6 +169,9 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
165
169
del self .context
166
170
budget_bytes = self ._set_device_memory_budget (budget_bytes )
167
171
self .context = self .engine .create_execution_context ()
172
+ # Indicates to reevaluate the runtime settings
173
+ self .has_context_changed = True
174
+
168
175
return budget_bytes
169
176
170
177
def _set_device_memory_budget (self , budget_bytes : int ) -> int :
@@ -353,11 +360,16 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
353
360
self ._check_initialized ()
354
361
355
362
cudagraphs_enabled = torch_tensorrt .runtime .get_cudagraphs_mode ()
363
+ < << << << HEAD
356
364
shape_changed = self .validate_input_shapes (inputs )
357
365
need_cudagraphs_record , can_use_pre_allocated_outputs = (
358
366
self .runtime_states .validate_states (
359
367
cudagraphs_enabled , self .use_pre_allocated_outputs , shape_changed
360
368
)
369
+ == == == =
370
+ need_cudagraphs_record = cudagraphs_enabled and (
371
+ not self .cudagraphs_validate_shapes (inputs ) or self .has_context_changed
372
+ >> > >> >> 7 bb66dac4 (fix : Record cudagraphs when weight streaming budget has changed )
361
373
)
362
374
363
375
if need_cudagraphs_record :
@@ -366,11 +378,18 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
366
378
self ._input_buffers = [None ] * len (self .input_names )
367
379
self ._output_buffers = [None ] * len (self .output_names )
368
380
369
- if not cudagraphs_enabled and self .cudagraph :
381
+ if self . cudagraph and ( not cudagraphs_enabled or self .has_context_changed ) :
370
382
self .cudagraph .reset ()
371
383
self .cudagraph = None
372
384
385
+ << << < << HEAD
373
386
# If in safe mode, check at each iteration for whether a switch is required
387
+ == == == =
388
+ # Reset the flag
389
+ self .has_context_changed = False
390
+
391
+ # If in safe mode, check at each iteration for whether a switch is required
392
+ >> >> > >> 7 bb66dac4 (fix : Record cudagraphs when weight streaming budget has changed )
374
393
if (
375
394
torch_tensorrt .runtime ._multi_device_safe_mode ._PY_RT_MULTI_DEVICE_SAFE_MODE
376
395
):
0 commit comments