24
24
25
25
26
26
class TorchTRTRuntimeStates :
27
- def __init__ (self , new_cudagraphs : bool , new_pre_allocated_output : bool ):
27
+ def __init__ (self , new_cudagraphs : bool ):
28
28
# Indicates whether CUDAGraphs were enabled in the previous execute_engine
29
29
self .old_cudagraphs = new_cudagraphs
30
30
# Indicates whether pre-allocated output was enabled in the previous execute_engine
31
- self .old_pre_allocated_outputs = new_pre_allocated_output
31
+ self .old_pre_allocated_outputs = False
32
+ # Indicates whether the execution context has been recreated since the last call to set_runtime_states
33
+ self .context_changed = False
32
34
33
35
def set_runtime_states (
34
36
self ,
35
37
new_cudagraphs : bool ,
36
38
new_pre_allocated_output : bool ,
37
39
shape_changed : bool ,
38
- ) -> Tuple [bool , bool ]:
39
- # Evaluates whether certain conditions are met to enable CUDA Graph recording or to reuse pre-allocated outputs
40
+ ) -> Tuple [bool , bool , bool ]:
41
+ # Evaluates whether certain conditions are met to enable CUDA Graph recording or to use pre-allocated outputs
40
42
# based on the current and previous states, as well as whether the input shape has changed
41
43
need_cudagraphs_record = False
42
44
can_use_pre_allocated_outputs = False
43
-
44
- # Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
45
- if new_cudagraphs and (not self .old_cudagraphs or shape_changed ):
45
+ need_cudagraphs_reset = False
46
+
47
+ # CUDA Graph recording is needed if CUDA graphs are enabled and:
48
+ # - CUDA graphs were previously disabled
49
+ # - or the shape has changed
50
+ # - or the execution context has changed (e.g., weight streaming)
51
+ if new_cudagraphs and (
52
+ not self .old_cudagraphs or shape_changed or self .context_changed
53
+ ):
46
54
need_cudagraphs_record = True
47
55
48
56
# Pre-allocated outputs can be used when both the previous and current states are true and the shape has not changed
@@ -53,10 +61,19 @@ def set_runtime_states(
53
61
):
54
62
can_use_pre_allocated_outputs = True
55
63
64
+ if not new_cudagraphs or shape_changed or self .context_changed :
65
+ need_cudagraphs_reset = True
66
+
56
67
self .old_cudagraphs = new_cudagraphs
57
68
self .old_pre_allocated_outputs = new_pre_allocated_output
69
+ # Reset the context_changed flag now that it has been consumed by the checks above
70
+ self .context_changed = False
58
71
59
- return need_cudagraphs_record , can_use_pre_allocated_outputs
72
+ return (
73
+ need_cudagraphs_record ,
74
+ can_use_pre_allocated_outputs ,
75
+ need_cudagraphs_reset ,
76
+ )
60
77
61
78
62
79
class PythonTorchTensorRTModule (Module ): # type: ignore[misc]
@@ -145,7 +162,7 @@ def __init__(
145
162
self .weight_name_map = weight_name_map
146
163
self .target_platform = Platform .current_platform ()
147
164
self .runtime_states = TorchTRTRuntimeStates (
148
- torch_tensorrt .runtime .get_cudagraphs_mode (), False
165
+ torch_tensorrt .runtime .get_cudagraphs_mode ()
149
166
)
150
167
self .pre_allocated_outputs : List [torch .Tensor ] = []
151
168
self .use_pre_allocated_outputs = False
@@ -168,6 +185,7 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
168
185
del self .context
169
186
budget_bytes = self ._set_device_memory_budget (budget_bytes )
170
187
self .context = self .engine .create_execution_context ()
188
+ self .runtime_states .context_changed = True
171
189
return budget_bytes
172
190
173
191
def _set_device_memory_budget (self , budget_bytes : int ) -> int :
@@ -200,7 +218,6 @@ def setup_engine(self) -> None:
200
218
if self .settings .enable_weight_streaming :
201
219
self .set_default_device_memory_budget ()
202
220
self .context = self .engine .create_execution_context ()
203
-
204
221
assert self .engine .num_io_tensors == (
205
222
len (self .input_names ) + len (self .output_names )
206
223
)
@@ -356,22 +373,22 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
356
373
357
374
cudagraphs_enabled = torch_tensorrt .runtime .get_cudagraphs_mode ()
358
375
shape_changed = self .validate_input_shapes (inputs )
359
- need_cudagraphs_record , can_use_pre_allocated_outputs = (
360
- self .runtime_states .set_runtime_states (
361
- cudagraphs_enabled , self .use_pre_allocated_outputs , shape_changed
362
- )
376
+ (
377
+ need_cudagraphs_record ,
378
+ can_use_pre_allocated_outputs ,
379
+ need_cudagraphs_reset ,
380
+ ) = self .runtime_states .set_runtime_states (
381
+ cudagraphs_enabled , self .use_pre_allocated_outputs , shape_changed
363
382
)
364
383
384
+ if need_cudagraphs_reset and self .cudagraph :
385
+ self .cudagraph .reset ()
386
+ self .cudagraph = None
387
+
365
388
if need_cudagraphs_record :
366
- if self .cudagraph :
367
- self .cudagraph .reset ()
368
389
self ._input_buffers = [None ] * len (self .input_names )
369
390
self ._output_buffers = [None ] * len (self .output_names )
370
391
371
- if not cudagraphs_enabled and self .cudagraph :
372
- self .cudagraph .reset ()
373
- self .cudagraph = None
374
-
375
392
# If in safe mode, check at each iteration whether a switch is required
376
393
if (
377
394
torch_tensorrt .runtime ._multi_device_safe_mode ._PY_RT_MULTI_DEVICE_SAFE_MODE
0 commit comments