Skip to content

Commit 6b3097e

Browse files
committed
chore: rebase and rename variable
1 parent d2ba0a0 commit 6b3097e

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

core/runtime/TRTEngine.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ struct TRTEngine : torch::CustomClassHolder {
103103
std::vector<at::Tensor> input_buffers = {};
104104
std::vector<at::Tensor> output_buffers = {};
105105
std::string shape_key = "None";
106-
bool cudagraphs_enabled = false;
106+
bool prev_cudagraphs_enabled = false;
107107
bool use_pre_allocated_outputs = true;
108108
std::vector<at::Tensor> pre_allocated_outputs;
109109

core/runtime/execute_engine.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,8 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
137137
// Whether cudagraphs needs to record the graph on this pass
138138
// Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
139139
bool need_cudagraphs_record =
140-
(((!compiled_engine->cudagraphs_enabled) && CUDAGRAPHS_MODE) || (CUDAGRAPHS_MODE && shape_changed));
141-
compiled_engine->cudagraphs_enabled = CUDAGRAPHS_MODE;
140+
(((!compiled_engine->prev_cudagraphs_enabled) && CUDAGRAPHS_MODE) || (CUDAGRAPHS_MODE && shape_changed));
141+
compiled_engine->prev_cudagraphs_enabled = CUDAGRAPHS_MODE;
142142

143143
if (!CUDAGRAPHS_MODE || shape_changed) {
144144
compiled_engine->cudagraph.reset();

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def __init__(
108108
self.engine = None
109109
self.weight_name_map = weight_name_map
110110
self.target_platform = Platform.current_platform()
111-
self.cudagraphs_enabled = False
111+
self.prev_cudagraphs_enabled = False
112112
self.pre_allocated_outputs: List[torch.Tensor] = []
113113
self.use_pre_allocated_outputs = True
114114

@@ -269,11 +269,11 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
269269
cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
270270
shape_changed = self.validate_input_shapes(inputs)
271271
# Cudagraphs record is required if cudagraphs_enabled is toggled to True regardless of shape change
272-
if not self.cudagraphs_enabled and cudagraphs_enabled:
272+
if not self.prev_cudagraphs_enabled and cudagraphs_enabled:
273273
need_cudagraphs_record = True
274274
else:
275275
need_cudagraphs_record = cudagraphs_enabled and shape_changed
276-
self.cudagraphs_enabled = cudagraphs_enabled
276+
self.prev_cudagraphs_enabled = cudagraphs_enabled
277277

278278
if need_cudagraphs_record:
279279
if self.cudagraph:

0 commit comments

Comments
 (0)