Skip to content

Commit 4a5f0d1

Browse files
committed
chore: rebase and rename variable
1 parent 210ae8b commit 4a5f0d1

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

core/runtime/TRTEngine.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ struct TRTEngine : torch::CustomClassHolder {
103103
std::vector<at::Tensor> input_buffers = {};
104104
std::vector<at::Tensor> output_buffers = {};
105105
std::string shape_key = "None";
106-
bool cudagraphs_enabled = false;
106+
bool prev_cudagraphs_enabled = false;
107107
bool use_pre_allocated_outputs = true;
108108
std::vector<at::Tensor> pre_allocated_outputs;
109109

core/runtime/execute_engine.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
135135
// Whether cudagraphs needs to record the graph on this pass
136136
// Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
137137
bool need_cudagraphs_record =
138-
(((!compiled_engine->cudagraphs_enabled) && CUDAGRAPHS_MODE) || (CUDAGRAPHS_MODE && shape_changed));
139-
compiled_engine->cudagraphs_enabled = CUDAGRAPHS_MODE;
138+
(((!compiled_engine->prev_cudagraphs_enabled) && CUDAGRAPHS_MODE) || (CUDAGRAPHS_MODE && shape_changed));
139+
compiled_engine->prev_cudagraphs_enabled = CUDAGRAPHS_MODE;
140140

141141
if (!CUDAGRAPHS_MODE || shape_changed) {
142142
compiled_engine->cudagraph.reset();

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def __init__(
107107
self.engine = None
108108
self.weight_name_map = weight_name_map
109109
self.target_platform = Platform.current_platform()
110-
self.cudagraphs_enabled = False
110+
self.prev_cudagraphs_enabled = False
111111
self.pre_allocated_outputs: List[torch.Tensor] = []
112112
self.use_pre_allocated_outputs = True
113113

@@ -268,11 +268,11 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
268268
cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
269269
shape_changed = self.validate_input_shapes(inputs)
270270
# Cudagraphs record is required if cudagraphs_enabled is toggled to True regardless of shape change
271-
if not self.cudagraphs_enabled and cudagraphs_enabled:
271+
if not self.prev_cudagraphs_enabled and cudagraphs_enabled:
272272
need_cudagraphs_record = True
273273
else:
274274
need_cudagraphs_record = cudagraphs_enabled and shape_changed
275-
self.cudagraphs_enabled = cudagraphs_enabled
275+
self.prev_cudagraphs_enabled = cudagraphs_enabled
276276

277277
if need_cudagraphs_record:
278278
if self.cudagraph:

0 commit comments

Comments
 (0)